Compare commits

...

57 Commits

Author SHA1 Message Date
Nuno Cruces
59f79e8e74 Optimize calls. 2023-05-02 01:08:04 +01:00
dependabot[bot]
40457721d7 Bump github.com/tetratelabs/wazero from 1.0.3 to 1.1.0 (#11)
Bumps [github.com/tetratelabs/wazero](https://github.com/tetratelabs/wazero) from 1.0.3 to 1.1.0.
- [Release notes](https://github.com/tetratelabs/wazero/releases)
- [Commits](https://github.com/tetratelabs/wazero/compare/v1.0.3...v1.1.0)

---
updated-dependencies:
- dependency-name: github.com/tetratelabs/wazero
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-05-02 01:03:42 +01:00
Nuno Cruces
18eeb85783 Improve mock. 2023-04-28 13:50:50 +01:00
Nuno Cruces
b36536979b Fix reopen. 2023-04-28 13:50:32 +01:00
Nuno Cruces
a6226c3b31 wazero v1.0.3. 2023-04-22 10:06:48 +01:00
Nuno Cruces
bed2ee7674 Refactor. 2023-04-22 00:15:44 +01:00
Nuno Cruces
7e6d178122 Fix. 2023-04-21 13:33:24 +01:00
Nuno Cruces
f360c77a78 Optimize blobs. (#10) 2023-04-21 13:31:45 +01:00
Nuno Cruces
759b11a05d wazero 1.0.2. 2023-04-18 23:33:56 +01:00
Nuno Cruces
93ce586139 Optimize time. 2023-04-18 01:00:59 +01:00
Nuno Cruces
2e5082c616 Query pragmas at startup. 2023-04-17 00:29:20 +01:00
Nuno Cruces
34acc28af8 Fix CI. 2023-04-14 15:48:20 +01:00
Nuno Cruces
c1a640f7d8 Build using wasi-sdk. 2023-04-14 15:31:17 +01:00
Nuno Cruces
005b15610a Memory optimizations. 2023-04-11 15:33:38 +01:00
Nuno Cruces
23ee4ccb0b Refactor. 2023-04-10 19:55:44 +01:00
Nuno Cruces
3a8cfd036d Dependencies. 2023-04-10 14:24:06 +01:00
Nuno Cruces
c38382fd8e Refactor. 2023-03-31 14:33:24 +01:00
Nuno Cruces
8509e0b6c8 Test coverage. 2023-03-31 13:42:31 +01:00
Nuno Cruces
9c07e57252 Refactor. 2023-03-29 15:06:22 +01:00
Nuno Cruces
80039385d3 Read only files. 2023-03-25 11:46:13 +00:00
Nuno Cruces
89f4327b2b Sync journal directories. 2023-03-25 11:16:51 +00:00
Nuno Cruces
37a3ff37e8 wazero 1.0. 2023-03-24 21:17:30 +00:00
Nuno Cruces
d880d6842c Refactor VFS. 2023-03-23 13:29:26 +00:00
Nuno Cruces
bef46e7954 Locking improvements (windows). 2023-03-23 12:40:55 +00:00
Nuno Cruces
4e72b4d117 Locking fix. 2023-03-23 11:26:19 +00:00
Nuno Cruces
3b08d02a83 Lock refactoring. 2023-03-23 01:55:54 +00:00
Nuno Cruces
b19c12c4c7 SQLite 3.41.2, prefer speed over size. 2023-03-23 00:44:43 +00:00
Nuno Cruces
859a21ef4e CI improvements. 2023-03-22 12:08:33 +00:00
Nuno Cruces
8ff0ee752f Use flock. 2023-03-22 03:15:54 +00:00
Nuno Cruces
589ad86f76 Extensions. 2023-03-21 00:13:12 +00:00
Nuno Cruces
1a3a1be1f6 Fix test. 2023-03-20 14:26:25 +00:00
Nuno Cruces
222c217bc8 Scripts. 2023-03-20 13:06:31 +00:00
Nuno Cruces
c1dc716391 VFS performance. 2023-03-20 11:02:34 +00:00
Nuno Cruces
71e1e5a8ee Avoid some copies. 2023-03-20 02:16:42 +00:00
Nuno Cruces
e4efb20c71 Generate coverage chart. 2023-03-18 03:51:05 +00:00
Nuno Cruces
2c9459d907 Add SQLite speedtest1. 2023-03-18 03:03:11 +00:00
Nuno Cruces
d0875e5fab Lock timeouts. 2023-03-18 01:13:31 +00:00
Nuno Cruces
15dec13f15 FCNTL_SIZE_HINT, refactor. 2023-03-17 17:13:03 +00:00
Nuno Cruces
f38e36109a FCNTL_HAS_MOVED. 2023-03-17 14:11:09 +00:00
Nuno Cruces
4cb65ccbd9 xFileControl, xDeviceCharacteristics, PSOW. 2023-03-17 13:39:19 +00:00
Nuno Cruces
f789c2fb8b OPEN_NOFOLLOW. 2023-03-16 12:27:44 +00:00
Nuno Cruces
c6a2617dfc Locking fixes. 2023-03-16 02:52:22 +00:00
Nuno Cruces
6fc0afcd12 Towards lock timeouts. 2023-03-15 13:58:16 +00:00
Nuno Cruces
77088962f5 SQLite 3.41.1. 2023-03-15 13:29:09 +00:00
Nuno Cruces
71da34861b Fix time collation. 2023-03-13 04:19:58 +00:00
Nuno Cruces
56e8281bdb Time collation tests. 2023-03-10 16:42:20 +00:00
Nuno Cruces
f61d430e65 Documentation. 2023-03-10 16:26:19 +00:00
Nuno Cruces
dbaed53b9a Sync and delete improvements. 2023-03-10 14:17:02 +00:00
Nuno Cruces
8b1bfd04e3 Simplify windows hacks. 2023-03-10 10:43:02 +00:00
Nuno Cruces
11c1687146 Time collation. 2023-03-09 14:42:29 +00:00
Nuno Cruces
94c43a8685 Use access syscall. 2023-03-09 01:59:46 +00:00
Nuno Cruces
a25159a070 Fix sharing violation. 2023-03-09 01:23:52 +00:00
Nuno Cruces
e007e9b060 Windows fixes. 2023-03-08 20:10:46 +00:00
Nuno Cruces
66a730893f Fix readonly transaction rollback. 2023-03-08 18:07:21 +00:00
Nuno Cruces
926adeb3f5 Remove MustPrepare. 2023-03-08 17:39:41 +00:00
Nuno Cruces
677f51bec1 Savepoint API. 2023-03-08 17:39:23 +00:00
Nuno Cruces
5d6f92b733 Documentation, tests, tweaks. 2023-03-08 13:29:33 +00:00
89 changed files with 3476 additions and 1918 deletions

View File

@@ -18,11 +18,27 @@ jobs:
with:
lfs: 'true'
- name: Set up Go
uses: actions/setup-go@v3
- name: Set up
uses: actions/setup-go@v4
with:
go-version: stable
cache: true
- name: Format
run: gofmt -s -w . && git diff --exit-code
if: matrix.os != 'windows-latest'
- name: Tidy
run: go mod tidy && git diff --exit-code
- name: Download
run: go mod download
- name: Verify
run: go mod verify
- name: Vet
run: go vet ./...
continue-on-error: true
- name: Build
run: go build -v ./...
@@ -30,12 +46,15 @@ jobs:
- name: Test
run: go test -v ./...
- name: Test data races
run: go test -v -race ./...
if: matrix.os == 'ubuntu-latest'
- name: Test BSD locks
run: go test -v -tags sqlite3_bsd ./...
if: matrix.os == 'macos-latest'
- name: Update coverage report
uses: ncruces/go-coverage-report@main
- name: Coverage report
uses: ncruces/go-coverage-report@v0
with:
chart: 'true'
amend: 'true'
if: |
matrix.os == 'ubuntu-latest' &&
github.event_name == 'push'

View File

@@ -18,6 +18,10 @@ embeds a build of SQLite into your application.
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
with a [pure Go](internal/vfs/) implementation.
This has numerous benefits, but also comes with some drawbacks.
#### Write-Ahead Logging
Because WASM does not support shared memory,
@@ -30,44 +34,47 @@ For non-WAL databases, `NORMAL` locking mode can be activated with
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
the `database/sql` driver defaults to `NORMAL` locking mode,
and WAL databases are not supported.
the `database/sql` driver defaults to `NORMAL` locking mode.
To open WAL databases, or use `EXCLUSIVE` locking mode,
disable connection pooling by calling
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
#### Open File Description Locks
#### POSIX Advisory Locks
On Unix, this module uses [OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
POSIX advisory locks, which SQLite uses, are
[broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
On Linux, macOS and illumos, this module uses
[OFD locks](https://www.gnu.org/software/libc/manual/html_node/Open-File-Description-Locks.html)
to synchronize access to database files.
OFD locks are fully compatible with process-associated POSIX advisory locks.
POSIX advisory locks, which SQLite uses, are [broken by design](https://www.sqlite.org/src/artifact/90c4fa?ln=1073-1161).
OFD locks are fully compatible with process-associated POSIX advisory locks,
and are supported on Linux, macOS and illumos.
As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.org/uri.html).
On BSD Unixes, this module uses
[BSD locks](https://man.freebsd.org/cgi/man.cgi?query=flock&sektion=2).
BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
#### Testing
The pure Go VFS is stress tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
[speedtest1](https://github.com/sqlite/sqlite/blob/master/test/speedtest1.c).
### Roadmap
- [x] build SQLite using `zig cc --target=wasm32-wasi`
- [x] `:memory:` databases
- [x] port [`test_demovfs.c`](https://www.sqlite.org/src/doc/trunk/src/test_demovfs.c) to Go
- branch [`wasi`](https://github.com/ncruces/go-sqlite3/tree/wasi) uses `test_demovfs.c` directly
- [x] design a nice API, enough for simple use cases
- [x] provide a simple `database/sql` driver
- [x] file locking, compatible with SQLite on macOS/Linux/Windows
- [ ] advanced SQLite features
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [ ] snapshots
- [ ] session extension
- [ ] resumable bulk update
- [ ] shared-cache mode
- [ ] unlock-notify
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] in-memory VFS
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] in-memory VFS, wrapping a [`bytes.Buffer`](https://pkg.go.dev/bytes#Buffer)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom VFS API
### Alternatives
- [`modernc.org/sqlite`](https://pkg.go.dev/modernc.org/sqlite)

View File

@@ -11,7 +11,7 @@ type Backup struct {
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
//
// Backup calls [Conn.Open] to open the SQLite database file dstURI,
// Backup calls [Open] to open the SQLite database file dstURI,
// and blocks until the entire backup is complete.
// Use [Conn.BackupInit] for incremental backup.
//
@@ -28,12 +28,12 @@ func (src *Conn) Backup(srcDB, dstURI string) error {
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
//
// Restore calls [Conn.Open] to open the SQLite database file srcURI,
// Restore calls [Open] to open the SQLite database file srcURI,
// and blocks until the entire restore is complete.
//
// https://www.sqlite.org/backup.html
func (dst *Conn) Restore(dstDB, srcURI string) error {
src, err := dst.openDB(srcURI, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
src, err := dst.openDB(srcURI, OPEN_READONLY|OPEN_URI)
if err != nil {
return err
}
@@ -48,7 +48,7 @@ func (dst *Conn) Restore(dstDB, srcURI string) error {
// BackupInit initializes a backup operation to copy the content of one database into another.
//
// BackupInit calls [Conn.Open] to open the SQLite database file dstURI,
// BackupInit calls [Open] to open the SQLite database file dstURI,
// then initializes a backup that copies the contents of srcDB on the src connection
// to the "main" database in dstURI.
//

131
blob.go
View File

@@ -1,6 +1,10 @@
package sqlite3
import "io"
import (
"io"
"github.com/ncruces/go-sqlite3/internal/util"
)
// ZeroBlob represents a zero-filled, length n BLOB
// that can be used as an argument to
@@ -14,9 +18,9 @@ type ZeroBlob int64
// https://www.sqlite.org/c3ref/blob.html
type Blob struct {
c *Conn
handle uint32
bytes int64
offset int64
handle uint32
}
var _ io.ReadWriteSeeker = &Blob{}
@@ -25,6 +29,7 @@ var _ io.ReadWriteSeeker = &Blob{}
//
// https://www.sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.reset()
blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db)
@@ -45,7 +50,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
}
blob := Blob{c: c}
blob.handle = c.mem.readUint32(blobPtr)
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
return &blob, nil
}
@@ -81,14 +86,14 @@ func (b *Blob) Read(p []byte) (n int, err error) {
return 0, io.EOF
}
want := int64(len(p))
avail := b.bytes - b.offset
want := int64(len(p))
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
defer b.c.arena.reset()
ptr := b.c.arena.new(uint64(want))
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
@@ -96,30 +101,68 @@ func (b *Blob) Read(p []byte) (n int, err error) {
if err != nil {
return 0, err
}
mem := b.c.mem.view(ptr, uint64(want))
copy(p, mem)
b.offset += want
if b.offset >= b.bytes {
err = io.EOF
}
copy(p, util.View(b.c.mod, ptr, uint64(want)))
return int(want), err
}
// WriteTo implements the [io.WriterTo] interface.
//
// https://www.sqlite.org/c3ref/blob_read.html
func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
if b.offset >= b.bytes {
return 0, nil
}
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
for want > 0 {
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return n, err
}
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := w.Write(mem[:want])
b.offset += int64(m)
n += int64(m)
if err != nil {
return n, err
}
if int64(m) != want {
return n, io.ErrShortWrite
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
}
return n, nil
}
// Write implements the [io.Writer] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) Write(p []byte) (n int, err error) {
offset := b.offset
if offset > b.bytes {
offset = b.bytes
}
ptr := b.c.newBytes(p)
defer b.c.free(ptr)
defer b.c.arena.reset()
ptr := b.c.arena.bytes(p)
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(offset))
uint64(ptr), uint64(len(p)), uint64(b.offset))
err = b.c.error(r[0])
if err != nil {
return 0, err
@@ -128,11 +171,57 @@ func (b *Blob) Write(p []byte) (n int, err error) {
return len(p), nil
}
// ReadFrom implements the [io.ReaderFrom] interface.
//
// https://www.sqlite.org/c3ref/blob_write.html
func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
avail := b.bytes - b.offset
want := int64(65536)
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
ptr := b.c.new(uint64(want))
defer b.c.free(ptr)
for {
mem := util.View(b.c.mod, ptr, uint64(want))
m, err := r.Read(mem[:want])
if m > 0 {
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(m), uint64(b.offset))
err := b.c.error(r[0])
if err != nil {
return n, err
}
b.offset += int64(m)
n += int64(m)
}
if err == io.EOF {
return n, nil
}
if err != nil {
return n, err
}
avail = b.bytes - b.offset
if want > avail {
want = avail
}
if want < 1 {
want = 1
}
}
}
// Seek implements the [io.Seeker] interface.
func (b *Blob) Seek(offset int64, whence int) (int64, error) {
switch whence {
default:
return 0, whenceErr
return 0, util.WhenceErr
case io.SeekStart:
break
case io.SeekCurrent:
@@ -141,7 +230,7 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
offset += b.bytes
}
if offset < 0 {
return 0, offsetErr
return 0, util.OffsetErr
}
b.offset = offset
return offset, nil
@@ -151,8 +240,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))[0])
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
b.offset = 0
return b.c.error(r[0])
return err
}

81
conn.go
View File

@@ -3,12 +3,15 @@ package sqlite3
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync/atomic"
"unsafe"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Conn is a database connection handle.
@@ -18,26 +21,31 @@ import (
type Conn struct {
*module
handle uint32
arena arena
interrupt context.Context
waiter chan struct{}
pending *Stmt
arena arena
handle uint32
}
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
}
// OpenFlags opens an SQLite database file as specified by the filename argument.
//
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
//
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
//
// https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE
}
return newConn(filename, flags)
}
@@ -50,7 +58,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
if conn == nil {
mod.close()
} else {
runtime.SetFinalizer(conn, finalizer[Conn](3))
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
}
}()
@@ -68,9 +76,10 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
connPtr := c.arena.new(ptrlen)
namePtr := c.arena.string(filename)
flags |= OPEN_EXRESCODE
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := c.mem.readUint32(connPtr)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r[0], handle); err != nil {
c.closeDB(handle)
return 0, err
@@ -87,13 +96,18 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
}
}
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
c.closeDB(handle)
return 0, fmt.Errorf("sqlite3: invalid _pragma: %w", err)
return 0, err
}
}
return handle, nil
}
@@ -119,6 +133,8 @@ func (c *Conn) Close() error {
}
c.SetInterrupt(context.Background())
c.pending.Close()
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
@@ -143,23 +159,6 @@ func (c *Conn) Exec(sql string) error {
return c.error(r[0])
}
// MustPrepare calls [Conn.Prepare] and panics on error,
// a nil Stmt, or a non-empty tail.
func (c *Conn) MustPrepare(sql string) *Stmt {
s, tail, err := c.PrepareFlags(sql, 0)
if err != nil {
panic(err)
}
if s == nil {
panic(emptyErr)
}
if !emptyStatement(tail) {
s.Close()
panic(tailErr)
}
return s
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
@@ -186,8 +185,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
uint64(stmtPtr), uint64(tailPtr))
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
stmt.handle = util.ReadUint32(c.mod, stmtPtr)
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
@@ -247,26 +246,23 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
<-c.waiter // Wait for it to finish.
c.waiter = nil
}
// Reset the pending statement.
if c.pending != nil {
c.pending.Reset()
}
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx.Done() == nil {
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
if c.pending == nil {
c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
} else {
c.pending.Reset()
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
}
c.pending.Step()
// Don't create the goroutine if we're already interrupted.
// This happens frequently while restoring to a previously interrupted state.
@@ -282,7 +278,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
break
case <-ctx.Done(): // Done was closed.
buf := c.mem.view(c.handle+c.api.interrupt, 4)
buf := util.View(c.mod, c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
// Wait for the next call to SetInterrupt.
<-waiter
@@ -298,7 +294,7 @@ func (c *Conn) checkInterrupt() bool {
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
buf := c.mem.view(c.handle+c.api.interrupt, 4)
buf := util.View(c.mod, c.handle+c.api.interrupt, 4)
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
return true
}
@@ -306,15 +302,18 @@ func (c *Conn) checkInterrupt() bool {
// Pragma executes a PRAGMA statement and returns any results.
//
// https://www.sqlite.org/pragma.html
func (c *Conn) Pragma(str string) []string {
stmt := c.MustPrepare(`PRAGMA ` + str)
func (c *Conn) Pragma(str string) ([]string, error) {
stmt, _, err := c.Prepare(`PRAGMA ` + str)
if err != nil {
return nil, err
}
defer stmt.Close()
var pragmas []string
for stmt.Step() {
pragmas = append(pragmas, stmt.ColumnText(0))
}
return pragmas
return pragmas, stmt.Close()
}
func (c *Conn) error(rc uint64, sql ...string) error {
@@ -333,6 +332,6 @@ type DriverConn interface {
driver.ExecerContext
driver.ConnPrepareContext
Savepoint() (release func(*error))
Savepoint() Savepoint
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -9,8 +9,7 @@ const (
_UTF8 = 1
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_PATHNAME = 512
_MAX_STRING = 512 // Used for short strings: names, error messages…
_MAX_ALLOCATION_SIZE = 0x7ffffeff
@@ -133,6 +132,7 @@ const (
CONSTRAINT_DATATYPE ExtendedErrorCode = xErrorCode(CONSTRAINT) | (12 << 8)
NOTICE_RECOVER_WAL ExtendedErrorCode = xErrorCode(NOTICE) | (1 << 8)
NOTICE_RECOVER_ROLLBACK ExtendedErrorCode = xErrorCode(NOTICE) | (2 << 8)
NOTICE_RBU ExtendedErrorCode = xErrorCode(NOTICE) | (3 << 8)
WARNING_AUTOINDEX ExtendedErrorCode = xErrorCode(WARNING) | (1 << 8)
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
)
@@ -167,14 +167,6 @@ const (
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
)
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
//
// https://www.sqlite.org/c3ref/c_prepare_normalize.html

View File

@@ -35,6 +35,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
@@ -44,12 +45,12 @@ func init() {
type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
c, err := sqlite3.OpenFlags(name, sqlite3.OPEN_READWRITE|sqlite3.OPEN_CREATE|sqlite3.OPEN_URI|sqlite3.OPEN_EXRESCODE)
var c conn
c.conn, err = sqlite3.Open(name)
if err != nil {
return nil, err
}
var txBegin string
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
@@ -57,9 +58,9 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
switch s := query.Get("_txlock"); s {
case "":
txBegin = "BEGIN"
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
txBegin = "BEGIN " + s
c.txBegin = "BEGIN " + s
default:
c.Close()
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
@@ -69,32 +70,52 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
}
if len(pragmas) == 0 {
err := c.Exec(`
PRAGMA busy_timeout=60000;
err := c.conn.Exec(`
PRAGMA locking_mode=normal;
PRAGMA busy_timeout=60000;
`)
if err != nil {
c.Close()
return nil, err
}
c.reusable = true
} else {
s, _, err := c.conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
`)
if err != nil {
c.Close()
return nil, err
}
if s.Step() {
c.reusable = s.ColumnText(0) == "normal"
c.readOnly = s.ColumnRawText(1)[0] // 0 or 1
}
err = s.Close()
if err != nil {
c.Close()
return nil, err
}
}
return conn{
conn: c,
txBegin: txBegin,
}, nil
return c, nil
}
type conn struct {
conn *sqlite3.Conn
txBegin string
txCommit string
conn *sqlite3.Conn
txBegin string
txCommit string
txRollback string
reusable bool
readOnly byte
}
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.Validator = conn{}
_ sqlite3.DriverConn = conn{}
)
@@ -102,27 +123,36 @@ func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() bool {
return c.reusable
}
func (c conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
switch opts.Isolation {
default:
return nil, isolationErr
case driver.IsolationLevel(sql.LevelDefault):
case driver.IsolationLevel(sql.LevelSerializable):
}
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
if opts.ReadOnly {
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + c.conn.Pragma("query_only")[0]
txBegin = `
BEGIN deferred;
PRAGMA query_only=on`
c.txCommit = `
ROLLBACK;
PRAGMA query_only=` + string(c.readOnly)
c.txRollback = c.txCommit
}
switch opts.Isolation {
default:
return nil, util.IsolationErr
case
driver.IsolationLevel(sql.LevelDefault),
driver.IsolationLevel(sql.LevelSerializable):
break
}
err := c.conn.Exec(txBegin)
@@ -134,14 +164,14 @@ func (c conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, er
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil {
if err != nil && !c.conn.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(`ROLLBACK`)
return c.conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
@@ -159,7 +189,7 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
if st != nil {
s.Close()
st.Close()
return nil, tailErr
return nil, util.TailErr
}
}
return stmt{s, c.conn}, nil
@@ -189,7 +219,7 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
}, nil
}
func (c conn) Savepoint() (release func(*error)) {
func (c conn) Savepoint() sqlite3.Savepoint {
return c.conn.Savepoint()
}
@@ -288,11 +318,11 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
panic(util.AssertErr())
}
}
if err != nil {
@@ -359,11 +389,10 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = r.stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeTime(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.stmt.ColumnBlob(i, buf)
dest[i] = r.stmt.ColumnRawBlob(i)
case sqlite3.TEXT:
dest[i] = stringOrTime(r.stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
@@ -371,7 +400,7 @@ func (r rows) Next(dest []driver.Value) error {
dest[i] = nil
}
default:
panic(assertErr)
panic(util.AssertErr())
}
}

View File

@@ -1,4 +1,3 @@
// Package driver provides a database/sql driver for SQLite.
package driver
import (
@@ -12,6 +11,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_Open_dir(t *testing.T) {
@@ -134,14 +134,16 @@ func Test_BeginTx(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.db"))
db, err := sql.Open("sqlite3", "file:"+
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db"))+
"?_txlock=exclusive&_pragma=busy_timeout(0)")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err.Error() != string(isolationErr) {
if err.Error() != string(util.IsolationErr) {
t.Error("want isolationErr")
}
@@ -219,7 +221,7 @@ func Test_Prepare(t *testing.T) {
}
_, err = db.Prepare(`SELECT 1; SELECT 2`)
if err.Error() != string(tailErr) {
if err.Error() != string(util.TailErr) {
t.Error("want tailErr")
}
}
@@ -297,7 +299,7 @@ func Test_QueryRow_blob_null(t *testing.T) {
want := [][]byte{nil, {0xca, 0xfe}, {0xba, 0xbe}, nil}
for i := 0; rows.Next(); i++ {
var buf []byte
var buf sql.RawBytes
err = rows.Scan(&buf)
if err != nil {
t.Fatal(err)

View File

@@ -1,11 +0,0 @@
package driver
type errorString string
func (e errorString) Error() string { return string(e) }
const (
assertErr = errorString("sqlite3: assertion failed")
tailErr = errorString("sqlite3: multiple statements")
isolationErr = errorString("sqlite3: unsupported isolation level")
)

View File

@@ -9,23 +9,23 @@ import (
// if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) driver.Value {
func stringOrTime(text []byte) driver.Value {
// Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") {
return text
return string(text)
}
if len(text) > len(time.RFC3339Nano) {
return text
return string(text)
}
if text[4] != '-' || text[10] != 'T' || text[16] != ':' {
return text
return string(text)
}
// Slow path.
date, err := time.Parse(time.RFC3339Nano, text)
if err == nil && date.Format(time.RFC3339Nano) == text {
date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && date.Format(time.RFC3339Nano) == string(text) {
return date
}
return text
return string(text)
}

View File

@@ -6,7 +6,7 @@ import (
)
// This checks that any string can be recovered as the same string.
func Fuzz_maybeTime_1(f *testing.F) {
func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("")
f.Add(" ")
f.Add("SQLite")
@@ -22,7 +22,7 @@ func Fuzz_maybeTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) {
value := maybeTime(str)
value := stringOrTime([]byte(str))
switch v := value.(type) {
case time.Time:
@@ -48,7 +48,7 @@ func Fuzz_maybeTime_1(f *testing.F) {
// This checks that any [time.Time] can be recovered as a [time.Time],
// with nanosecond accuracy, and preserving any timezone offset.
func Fuzz_maybeTime_2(f *testing.F) {
func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(0, 0)
f.Add(0, 1)
f.Add(0, -1)
@@ -59,7 +59,7 @@ func Fuzz_maybeTime_2(f *testing.F) {
f.Add(-763421161058, 222_222_222) // twosday, year 22222BC
checkTime := func(t *testing.T, date time.Time) {
value := maybeTime(date.Format(time.RFC3339Nano))
value := stringOrTime([]byte(date.Format(time.RFC3339Nano)))
switch v := value.(type) {
case time.Time:

View File

@@ -48,7 +48,8 @@ func ExampleDriverConn() {
err = conn.Raw(func(driverConn any) error {
conn := driverConn.(sqlite3.DriverConn)
defer conn.Savepoint()(&err)
savept := conn.Savepoint()
defer savept.Release(&err)
blob, err := conn.OpenBlob("main", "test", "col", id, true)
if err != nil {

23
embed/README.md Normal file
View File

@@ -0,0 +1,23 @@
# Embeddable WASM build of SQLite
This folder includes an embeddable WASM build of SQLite 3.41.2 for use with
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
The following optional features are compiled in:
- [math functions](https://www.sqlite.org/lang_mathfunc.html)
- [FTS3/4](https://www.sqlite.org/fts3.html)/[5](https://www.sqlite.org/fts5.html)
- [JSON](https://www.sqlite.org/json1.html)
- [R*Tree](https://www.sqlite.org/rtree.html)
- [GeoPoly](https://www.sqlite.org/geopoly.html)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)
- [decimal](https://github.com/sqlite/sqlite/blob/master/ext/misc/decimal.c)
- [regexp](https://github.com/sqlite/sqlite/blob/master/ext/misc/regexp.c)
- [series](https://github.com/sqlite/sqlite/blob/master/ext/misc/series.c)
- [uint](https://github.com/sqlite/sqlite/blob/master/ext/misc/uint.c)
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
- [time](../sqlite3/time.c)
See the [configuration options](../sqlite3/sqlite_cfg.h).
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
and [`binaryen`](https://github.com/WebAssembly/binaryen).

View File

@@ -1,16 +1,28 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
# download SQLite
../sqlite3/download.sh
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
# build SQLite
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o sqlite3.wasm ../sqlite3/amalg.c \
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \
-mexec-model=reactor \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--initial-memory=327680 \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
$(awk '{print "-Wl,--export="$0}' ../sqlite3/exports.txt)
$(awk '{print "-Wl,--export="$0}' exports.txt)
trap 'rm -f sqlite3.tmp' EXIT
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext

View File

@@ -40,7 +40,6 @@ sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_unlock_notify
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish

View File

@@ -4,9 +4,6 @@
// with an appropriate build of SQLite:
//
// import _ "github.com/ncruces/go-sqlite3/embed"
//
// You can obtain this build of SQLite from:
// https://github.com/ncruces/go-sqlite3/tree/main/embed
package embed
import (

Binary file not shown.

View File

@@ -1,8 +1,6 @@
package sqlite3
import (
"fmt"
"runtime"
"strconv"
"strings"
)
@@ -11,10 +9,10 @@ import (
//
// https://www.sqlite.org/c3ref/errcode.html
type Error struct {
code uint64
str string
msg string
sql string
code uint64
}
// Code returns the primary error code for this error.
@@ -188,39 +186,3 @@ func (e ExtendedErrorCode) Temporary() bool {
func (e ExtendedErrorCode) Timeout() bool {
return e == BUSY_TIMEOUT
}
type errorString string
func (e errorString) Error() string { return string(e) }
const (
nilErr = errorString("sqlite3: invalid memory address or null pointer dereference")
oomErr = errorString("sqlite3: out of memory")
rangeErr = errorString("sqlite3: index out of range")
noNulErr = errorString("sqlite3: missing NUL terminator")
noGlobalErr = errorString("sqlite3: could not find global: ")
noFuncErr = errorString("sqlite3: could not find function: ")
binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded")
timeErr = errorString("sqlite3: invalid time value")
emptyErr = errorString("sqlite3: empty statement")
tailErr = errorString("sqlite3: non-empty tail")
notImplErr = errorString("sqlite3: not implemented")
whenceErr = errorString("sqlite3: invalid whence")
offsetErr = errorString("sqlite3: invalid offset")
)
func assertErr() errorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return errorString(msg)
}
func finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(errorString(msg)) }
}

View File

@@ -1,15 +1,16 @@
package sqlite3
import (
"context"
"errors"
"strings"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_assertErr(t *testing.T) {
err := assertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") {
err := util.AssertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") {
t.Errorf("got %q", s)
}
}
@@ -120,10 +121,8 @@ func Test_ErrorCode_Error(t *testing.T) {
// Test all error codes.
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
@@ -144,10 +143,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r, _ := db.api.errstr.Call(context.TODO(), uint64(i))
if r != nil {
want += db.mem.readString(uint32(r[0]), _MAX_STRING)
}
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {

View File

@@ -26,7 +26,11 @@ func Example() {
log.Fatal(err)
}
stmt := db.MustPrepare(`SELECT id, name FROM users`)
stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()
for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))

4
go.mod
View File

@@ -4,9 +4,9 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.0.0-rc.1
github.com/tetratelabs/wazero v1.1.0
golang.org/x/sync v0.1.0
golang.org/x/sys v0.6.0
golang.org/x/sys v0.7.0
)
retract v0.4.0 // tagged from the wrong branch

8
go.sum
View File

@@ -1,8 +1,8 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.0.0-rc.1 h1:ytecMV5Ue0BwezjKh/cM5yv1Mo49ep2R2snSsQUyToc=
github.com/tetratelabs/wazero v1.0.0-rc.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ=
github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

42
internal/util/error.go Normal file
View File

@@ -0,0 +1,42 @@
package util
import (
"fmt"
"runtime"
"strconv"
)
type ErrorString string
func (e ErrorString) Error() string { return string(e) }
const (
NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference")
OOMErr = ErrorString("sqlite3: out of memory")
RangeErr = ErrorString("sqlite3: index out of range")
NoNulErr = ErrorString("sqlite3: missing NUL terminator")
NoGlobalErr = ErrorString("sqlite3: could not find global: ")
NoFuncErr = ErrorString("sqlite3: could not find function: ")
BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded")
TimeErr = ErrorString("sqlite3: invalid time value")
WhenceErr = ErrorString("sqlite3: invalid whence")
OffsetErr = ErrorString("sqlite3: invalid offset")
TailErr = ErrorString("sqlite3: multiple statements")
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
)
func AssertErr() ErrorString {
msg := "sqlite3: assertion failed"
if _, file, line, ok := runtime.Caller(1); ok {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return ErrorString(msg)
}
func Finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(ErrorString(msg)) }
}

81
internal/util/func.go Normal file
View File

@@ -0,0 +1,81 @@
package util
import (
"context"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
type i32 interface{ ~int32 | ~uint32 }
type i64 interface{ ~int64 | ~uint64 }
func RegisterFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
}),
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func RegisterFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func RegisterFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func RegisterFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func RegisterFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3, _ T4) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func RegisterFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}
func RegisterFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) TR) {
mod.NewFunctionBuilder().
WithGoModuleFunction(api.GoModuleFunc(
func(ctx context.Context, mod api.Module, stack []uint64) {
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
}),
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
Export(name)
}

110
internal/util/mem.go Normal file
View File

@@ -0,0 +1,110 @@
package util
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
func View(mod api.Module, ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(NilErr)
}
if size > math.MaxUint32 {
panic(RangeErr)
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)
}
return buf
}
func ReadUint32(mod api.Module, ptr uint32) uint32 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint32(mod api.Module, ptr uint32, v uint32) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadUint64(mod api.Module, ptr uint32) uint64 {
if ptr == 0 {
panic(NilErr)
}
v, ok := mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(RangeErr)
}
return v
}
func WriteUint64(mod api.Module, ptr uint32, v uint64) {
if ptr == 0 {
panic(NilErr)
}
ok := mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(RangeErr)
}
}
func ReadFloat64(mod api.Module, ptr uint32) float64 {
return math.Float64frombits(ReadUint64(mod, ptr))
}
func WriteFloat64(mod api.Module, ptr uint32, v float64) {
WriteUint64(mod, ptr, math.Float64bits(v))
}
func ReadString(mod api.Module, ptr, maxlen uint32) string {
if ptr == 0 {
panic(NilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
return string(buf[:i])
}
}
func WriteBytes(mod api.Module, ptr uint32, b []byte) {
buf := View(mod, ptr, uint64(len(b)))
copy(buf, b)
}
func WriteString(mod api.Module, ptr uint32, s string) {
buf := View(mod, ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

90
internal/util/mem_test.go Normal file
View File

@@ -0,0 +1,90 @@
package util
import (
"math"
"testing"
)
func TestView_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 0, 8)
t.Error("want panic")
}
func TestView_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 126, 8)
t.Error("want panic")
}
func TestView_overflow(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 1, math.MaxInt64)
t.Error("want panic")
}
func TestReadUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 0)
t.Error("want panic")
}
func TestReadUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 126)
t.Error("want panic")
}
func TestReadUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 0)
t.Error("want panic")
}
func TestReadUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 126)
t.Error("want panic")
}
func TestWriteUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 126, 1)
t.Error("want panic")
}
func TestWriteUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 126, 1)
t.Error("want panic")
}
func TestReadString_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadString(mock, 130, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -1,63 +1,54 @@
package sqlite3
package util
import (
"context"
"encoding/binary"
"math"
"github.com/tetratelabs/wazero/api"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func newMemory(size uint32) memory {
mem := make(mockMemory, size)
return memory{mockModule{&mem}}
func NewMockModule(size uint32) api.Module {
mem := mockMemory{buf: make([]byte, size)}
return mockModule{&mem, nil}
}
type mockModule struct {
memory api.Memory
api.Module
}
func (m mockModule) Memory() api.Memory { return m.memory }
func (m mockModule) String() string { return "mockModule" }
func (m mockModule) Name() string { return "mockModule" }
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
func (m mockModule) Close(context.Context) error { return nil }
type mockMemory []byte
type mockMemory struct {
buf []byte
api.Memory
}
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
func (m mockMemory) Size() uint32 { return uint32(len(m.buf)) }
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
if offset >= m.Size() {
return 0, false
}
return m[offset], true
return m.buf[offset], true
}
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
if !m.hasSize(offset, 2) {
return 0, false
}
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
return binary.LittleEndian.Uint16(m.buf[offset : offset+2]), true
}
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
if !m.hasSize(offset, 4) {
return 0, false
}
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
return binary.LittleEndian.Uint32(m.buf[offset : offset+4]), true
}
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
@@ -72,7 +63,7 @@ func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
if !m.hasSize(offset, 8) {
return 0, false
}
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
return binary.LittleEndian.Uint64(m.buf[offset : offset+8]), true
}
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
@@ -87,14 +78,14 @@ func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
if !m.hasSize(offset, byteCount) {
return nil, false
}
return m[offset : offset+byteCount : offset+byteCount], true
return m.buf[offset : offset+byteCount : offset+byteCount], true
}
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
if offset >= m.Size() {
return false
}
m[offset] = v
m.buf[offset] = v
return true
}
@@ -102,7 +93,7 @@ func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
if !m.hasSize(offset, 2) {
return false
}
binary.LittleEndian.PutUint16(m[offset:], v)
binary.LittleEndian.PutUint16(m.buf[offset:], v)
return true
}
@@ -110,7 +101,7 @@ func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
if !m.hasSize(offset, 4) {
return false
}
binary.LittleEndian.PutUint32(m[offset:], v)
binary.LittleEndian.PutUint32(m.buf[offset:], v)
return true
}
@@ -122,7 +113,7 @@ func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
if !m.hasSize(offset, 8) {
return false
}
binary.LittleEndian.PutUint64(m[offset:], v)
binary.LittleEndian.PutUint64(m.buf[offset:], v)
return true
}
@@ -134,7 +125,7 @@ func (m mockMemory) Write(offset uint32, val []byte) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
copy(m.buf[offset:], val)
return true
}
@@ -142,20 +133,16 @@ func (m mockMemory) WriteString(offset uint32, val string) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m[offset:], val)
copy(m.buf[offset:], val)
return true
}
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
prev := (len(*m) + 65535) / 65536
*m = append(*m, make([]byte, 65536*delta)...)
prev := (len(m.buf) + 65535) / 65536
m.buf = append(m.buf, make([]byte, 65536*delta)...)
return uint32(prev), true
}
func (m mockMemory) PageSize() (result uint32) {
return uint32(len(m) / 65536)
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
return uint64(offset)+uint64(byteCount) <= uint64(len(m.buf))
}

242
internal/util/mock_test.go Normal file
View File

@@ -0,0 +1,242 @@
package util
import (
"math"
"testing"
)
func Test_mockMemory_byte(t *testing.T) {
const want byte = 98
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadByte(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint16(t *testing.T) {
const want uint16 = 9876
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint16Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint16Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint32(t *testing.T) {
const want uint32 = 987654321
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint64(t *testing.T) {
const want uint64 = 9876543210
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_float32(t *testing.T) {
const want float32 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_float64(t *testing.T) {
const want float64 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_bytes(t *testing.T) {
const want string = "\xca\xfe\xba\xbe"
mock := NewMockModule(128)
_, ok := mock.Memory().Read(128, uint32(len(want)))
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(128, []byte(want))
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteString(128, want)
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(0, []byte(want))
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().Read(0, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
ok = mock.Memory().WriteString(64, want)
if !ok {
t.Error("want ok")
}
got, ok = mock.Memory().Read(64, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func Test_mockMemory_grow(t *testing.T) {
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(65536)
if ok {
t.Error("want error")
}
got, ok := mock.Memory().Grow(1)
if !ok {
t.Error("want ok")
}
if got != 1 {
t.Errorf("got %d, want 1", got)
}
_, ok = mock.Memory().ReadByte(65536)
if !ok {
t.Error("want ok")
}
}

217
internal/vfs/const.go Normal file
View File

@@ -0,0 +1,217 @@
package vfs
const (
_MAX_PATHNAME = 512
_DEFAULT_SECTOR_SIZE = 4096
)
// https://www.sqlite.org/rescode.html
type _ErrorCode uint32
const (
_OK _ErrorCode = 0 /* Successful result */
_PERM _ErrorCode = 3 /* Access permission denied */
_BUSY _ErrorCode = 5 /* The database file is locked */
_IOERR _ErrorCode = 10 /* Some kind of disk I/O error occurred */
_NOTFOUND _ErrorCode = 12 /* Unknown opcode in sqlite3_file_control() */
_CANTOPEN _ErrorCode = 14 /* Unable to open the database file */
_IOERR_READ = _IOERR | (1 << 8)
_IOERR_SHORT_READ = _IOERR | (2 << 8)
_IOERR_WRITE = _IOERR | (3 << 8)
_IOERR_FSYNC = _IOERR | (4 << 8)
_IOERR_DIR_FSYNC = _IOERR | (5 << 8)
_IOERR_TRUNCATE = _IOERR | (6 << 8)
_IOERR_FSTAT = _IOERR | (7 << 8)
_IOERR_UNLOCK = _IOERR | (8 << 8)
_IOERR_RDLOCK = _IOERR | (9 << 8)
_IOERR_DELETE = _IOERR | (10 << 8)
_IOERR_BLOCKED = _IOERR | (11 << 8)
_IOERR_NOMEM = _IOERR | (12 << 8)
_IOERR_ACCESS = _IOERR | (13 << 8)
_IOERR_CHECKRESERVEDLOCK = _IOERR | (14 << 8)
_IOERR_LOCK = _IOERR | (15 << 8)
_IOERR_CLOSE = _IOERR | (16 << 8)
_IOERR_DIR_CLOSE = _IOERR | (17 << 8)
_IOERR_SHMOPEN = _IOERR | (18 << 8)
_IOERR_SHMSIZE = _IOERR | (19 << 8)
_IOERR_SHMLOCK = _IOERR | (20 << 8)
_IOERR_SHMMAP = _IOERR | (21 << 8)
_IOERR_SEEK = _IOERR | (22 << 8)
_IOERR_DELETE_NOENT = _IOERR | (23 << 8)
_IOERR_MMAP = _IOERR | (24 << 8)
_IOERR_GETTEMPPATH = _IOERR | (25 << 8)
_IOERR_CONVPATH = _IOERR | (26 << 8)
_IOERR_VNODE = _IOERR | (27 << 8)
_IOERR_AUTH = _IOERR | (28 << 8)
_IOERR_BEGIN_ATOMIC = _IOERR | (29 << 8)
_IOERR_COMMIT_ATOMIC = _IOERR | (30 << 8)
_IOERR_ROLLBACK_ATOMIC = _IOERR | (31 << 8)
_IOERR_DATA = _IOERR | (32 << 8)
_IOERR_CORRUPTFS = _IOERR | (33 << 8)
_CANTOPEN_NOTEMPDIR = _CANTOPEN | (1 << 8)
_CANTOPEN_ISDIR = _CANTOPEN | (2 << 8)
_CANTOPEN_FULLPATH = _CANTOPEN | (3 << 8)
_CANTOPEN_CONVPATH = _CANTOPEN | (4 << 8)
_CANTOPEN_DIRTYWAL = _CANTOPEN | (5 << 8) /* Not Used */
_CANTOPEN_SYMLINK = _CANTOPEN | (6 << 8)
_OK_SYMLINK = _OK | (2 << 8) /* internal use only */
)
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
type _OpenFlag uint32
const (
_OPEN_READONLY _OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
_OPEN_READWRITE _OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
_OPEN_CREATE _OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
_OPEN_DELETEONCLOSE _OpenFlag = 0x00000008 /* VFS only */
_OPEN_EXCLUSIVE _OpenFlag = 0x00000010 /* VFS only */
_OPEN_AUTOPROXY _OpenFlag = 0x00000020 /* VFS only */
_OPEN_URI _OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
_OPEN_MEMORY _OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
_OPEN_MAIN_DB _OpenFlag = 0x00000100 /* VFS only */
_OPEN_TEMP_DB _OpenFlag = 0x00000200 /* VFS only */
_OPEN_TRANSIENT_DB _OpenFlag = 0x00000400 /* VFS only */
_OPEN_MAIN_JOURNAL _OpenFlag = 0x00000800 /* VFS only */
_OPEN_TEMP_JOURNAL _OpenFlag = 0x00001000 /* VFS only */
_OPEN_SUBJOURNAL _OpenFlag = 0x00002000 /* VFS only */
_OPEN_SUPER_JOURNAL _OpenFlag = 0x00004000 /* VFS only */
_OPEN_NOMUTEX _OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
_OPEN_FULLMUTEX _OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
_OPEN_SHAREDCACHE _OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
_OPEN_PRIVATECACHE _OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
_OPEN_WAL _OpenFlag = 0x00080000 /* VFS only */
_OPEN_NOFOLLOW _OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
_OPEN_EXRESCODE _OpenFlag = 0x02000000 /* Extended result codes */
)
// https://www.sqlite.org/c3ref/c_access_exists.html
type _AccessFlag uint32
const (
_ACCESS_EXISTS _AccessFlag = 0
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
_ACCESS_READ _AccessFlag = 2 /* Unused */
)
// https://www.sqlite.org/c3ref/c_sync_dataonly.html
type _SyncFlag uint32
const (
_SYNC_NORMAL _SyncFlag = 0x00002
_SYNC_FULL _SyncFlag = 0x00003
_SYNC_DATAONLY _SyncFlag = 0x00010
)
// https://www.sqlite.org/c3ref/c_lock_exclusive.html
type _LockLevel uint32
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
_LOCK_NONE _LockLevel = 0 /* xUnlock() only */
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
_LOCK_SHARED _LockLevel = 1 /* xLock() or xUnlock() */
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
_LOCK_RESERVED _LockLevel = 2 /* xLock() only */
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
_LOCK_PENDING _LockLevel = 3 /* internal use only */
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
_LOCK_EXCLUSIVE _LockLevel = 4 /* xLock() only */
)
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type _FcntlOpcode uint32
const (
_FCNTL_LOCKSTATE _FcntlOpcode = 1
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
_FCNTL_SIZE_HINT _FcntlOpcode = 5
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
_FCNTL_FILE_POINTER _FcntlOpcode = 7
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
_FCNTL_OVERWRITE _FcntlOpcode = 11
_FCNTL_VFSNAME _FcntlOpcode = 12
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
_FCNTL_PRAGMA _FcntlOpcode = 14
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
_FCNTL_TRACE _FcntlOpcode = 19
_FCNTL_HAS_MOVED _FcntlOpcode = 20
_FCNTL_SYNC _FcntlOpcode = 21
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
_FCNTL_ZIPVFS _FcntlOpcode = 25
_FCNTL_RBU _FcntlOpcode = 26
_FCNTL_VFS_POINTER _FcntlOpcode = 27
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
_FCNTL_PDB _FcntlOpcode = 30
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
_FCNTL_DATA_VERSION _FcntlOpcode = 35
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
_FCNTL_CKPT_DONE _FcntlOpcode = 37
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
_FCNTL_CKPT_START _FcntlOpcode = 39
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
_FCNTL_CKSM_FILE _FcntlOpcode = 41
_FCNTL_RESET_CACHE _FcntlOpcode = 42
)
// https://www.sqlite.org/c3ref/c_iocap_atomic.html
type _DeviceCharacteristic uint32
const (
_IOCAP_ATOMIC _DeviceCharacteristic = 0x00000001
_IOCAP_ATOMIC512 _DeviceCharacteristic = 0x00000002
_IOCAP_ATOMIC1K _DeviceCharacteristic = 0x00000004
_IOCAP_ATOMIC2K _DeviceCharacteristic = 0x00000008
_IOCAP_ATOMIC4K _DeviceCharacteristic = 0x00000010
_IOCAP_ATOMIC8K _DeviceCharacteristic = 0x00000020
_IOCAP_ATOMIC16K _DeviceCharacteristic = 0x00000040
_IOCAP_ATOMIC32K _DeviceCharacteristic = 0x00000080
_IOCAP_ATOMIC64K _DeviceCharacteristic = 0x00000100
_IOCAP_SAFE_APPEND _DeviceCharacteristic = 0x00000200
_IOCAP_SEQUENTIAL _DeviceCharacteristic = 0x00000400
_IOCAP_UNDELETABLE_WHEN_OPEN _DeviceCharacteristic = 0x00000800
_IOCAP_POWERSAFE_OVERWRITE _DeviceCharacteristic = 0x00001000
_IOCAP_IMMUTABLE _DeviceCharacteristic = 0x00002000
_IOCAP_BATCH_ATOMIC _DeviceCharacteristic = 0x00004000
)

View File

@@ -16,13 +16,11 @@ import (
"sync/atomic"
"testing"
_ "unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
_ "github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/vfs"
)
//go:embed testdata/mptest.wasm
@@ -31,12 +29,6 @@ var binary []byte
//go:embed testdata/*.*test
var scripts embed.FS
//go:linkname vfsNewEnvModuleBuilder github.com/ncruces/go-sqlite3.vfsNewEnvModuleBuilder
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder
//go:linkname vfsContext github.com/ncruces/go-sqlite3.vfsContext
func vfsContext(ctx context.Context) (context.Context, io.Closer)
var (
rt wazero.Runtime
module wazero.CompiledModule
@@ -48,7 +40,8 @@ func init() {
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfsNewEnvModuleBuilder(rt)
env := vfs.Export(rt.NewHostModuleBuilder("env"))
env.NewFunctionBuilder().WithFunc(system).Export("system")
_, err := env.Instantiate(ctx)
if err != nil {
@@ -88,7 +81,7 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := vfsContext(ctx)
ctx, vfs := vfs.Context(ctx)
rt.InstantiateModule(ctx, module, cfg)
vfs.Close()
}()
@@ -96,14 +89,15 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
}
func Test_config01(t *testing.T) {
ctx, vfs := vfsContext(newContext(t))
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_config02(t *testing.T) {
@@ -114,14 +108,15 @@ func Test_config02(t *testing.T) {
t.Skip("skipping in CI")
}
ctx, vfs := vfsContext(newContext(t))
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_crash01(t *testing.T) {
@@ -129,14 +124,15 @@ func Test_crash01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfsContext(newContext(t))
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func Test_multiwrite01(t *testing.T) {
@@ -144,14 +140,15 @@ func Test_multiwrite01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := vfsContext(newContext(t))
ctx, vfs := vfs.Context(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
_, err := rt.InstantiateModule(ctx, module, cfg)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
vfs.Close()
mod.Close(ctx)
}
func newContext(t *testing.T) context.Context {

29
internal/vfs/tests/mptest/testdata/build.sh vendored Executable file
View File

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o mptest.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid
"$BINARYEN/wasm-opt" -g -O2 mptest.wasm -o mptest.tmp \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv mptest.tmp mptest.wasm

View File

@@ -1,22 +1,22 @@
#include <stdbool.h>
#include <stddef.h>
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
//
#include "os.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
void __attribute__((constructor)) premain() { sqlite3_initialize(); }
int sqlite3_enable_load_extension(sqlite3 *db, int onoff) { return SQLITE_OK; }
void *sqlite3_trace(sqlite3 *db, void (*xTrace)(void *, const char *),
void *pArg) {
return NULL;
}
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
}
__attribute__((constructor)) void premain() { sqlite3_initialize(); }
static int dont_unlink(const char *pathname) { return 0; }
#define sqlite3_enable_load_extension(...)
#define sqlite3_trace(...)
#define unlink dont_unlink
#undef UNUSED_PARAMETER
#include "mptest.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:37ceeed293b9f09e9770b40eda3f625447f5a3a74208709886d4411d12f93414
size 1486113

View File

@@ -0,0 +1,85 @@
package speedtest1
import (
"bytes"
"context"
"crypto/rand"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
_ "embed"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/internal/vfs"
)
//go:embed testdata/speedtest1.wasm
var binary []byte
var (
rt wazero.Runtime
module wazero.CompiledModule
output bytes.Buffer
options []string
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := vfs.Export(rt.NewHostModuleBuilder("env"))
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
}
module, err = rt.CompileModule(ctx, binary)
if err != nil {
panic(err)
}
}
func TestMain(m *testing.M) {
i := 1
options = append(options, "speedtest1")
for _, arg := range os.Args[1:] {
if strings.HasPrefix(arg, "-test.") {
os.Args[i] = arg
i++
} else {
options = append(options, arg)
}
}
os.Args = os.Args[:i]
code := m.Run()
io.Copy(os.Stderr, &output)
os.Exit(code)
}
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := vfs.Context(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().
WithArgs(args...).WithName("speedtest1").
WithStdout(&output).WithStderr(&output).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
b.Error(err)
}
vfs.Close()
mod.Close(ctx)
}

View File

@@ -0,0 +1 @@
speedtest1.wasm filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1 @@
speedtest1.c

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
-o speedtest1.wasm main.c \
-I"$ROOT/sqlite3" \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-Wl,--stack-first \
-Wl,--import-undefined \
-D_HAVE_SQLITE_CONFIG_H
"$BINARYEN/wasm-opt" -g -O2 speedtest1.wasm -o speedtest1.tmp \
--enable-multivalue --enable-mutable-globals \
--enable-bulk-memory --enable-reference-types \
--enable-nontrapping-float-to-int --enable-sign-ext
mv speedtest1.tmp speedtest1.wasm

View File

@@ -0,0 +1,16 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.c"
//
#include "os.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
#define randomFunc(args...) randomFunc2(args)
#include "speedtest1.c"

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8167119c344a68217b0301e2e8c288f2e75611d296d7822f841b65911da0275c
size 1520569

390
internal/vfs/vfs.go Normal file
View File

@@ -0,0 +1,390 @@
package vfs
import (
"context"
"crypto/rand"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
func Export(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.RegisterFuncIIJ(env, "os_localtime", vfsLocaltime)
util.RegisterFuncIIII(env, "os_randomness", vfsRandomness)
util.RegisterFuncIII(env, "os_sleep", vfsSleep)
util.RegisterFuncIII(env, "os_current_time", vfsCurrentTime)
util.RegisterFuncIII(env, "os_current_time_64", vfsCurrentTime64)
util.RegisterFuncIIIII(env, "os_full_pathname", vfsFullPathname)
util.RegisterFuncIIII(env, "os_delete", vfsDelete)
util.RegisterFuncIIIII(env, "os_access", vfsAccess)
util.RegisterFuncIIIIII(env, "os_open", vfsOpen)
util.RegisterFuncII(env, "os_close", vfsClose)
util.RegisterFuncIIIIJ(env, "os_read", vfsRead)
util.RegisterFuncIIIIJ(env, "os_write", vfsWrite)
util.RegisterFuncIIJ(env, "os_truncate", vfsTruncate)
util.RegisterFuncIII(env, "os_sync", vfsSync)
util.RegisterFuncIII(env, "os_file_size", vfsFileSize)
util.RegisterFuncIIII(env, "os_file_control", vfsFileControl)
util.RegisterFuncII(env, "os_sector_size", vfsSectorSize)
util.RegisterFuncII(env, "os_device_characteristics", vfsDeviceCharacteristics)
util.RegisterFuncIII(env, "os_lock", vfsLock)
util.RegisterFuncIII(env, "os_unlock", vfsUnlock)
util.RegisterFuncIII(env, "os_check_reserved_lock", vfsCheckReservedLock)
return env
}
type vfsKey struct{}
type vfsState struct {
files []vfsFile
}
func Context(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f.File != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode {
tm := time.Unix(t, 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
const size = 32 / 8
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
util.WriteUint32(mod, pTm+0*size, uint32(tm.Second()))
util.WriteUint32(mod, pTm+1*size, uint32(tm.Minute()))
util.WriteUint32(mod, pTm+2*size, uint32(tm.Hour()))
util.WriteUint32(mod, pTm+3*size, uint32(tm.Day()))
util.WriteUint32(mod, pTm+4*size, uint32(tm.Month()-time.January))
util.WriteUint32(mod, pTm+5*size, uint32(tm.Year()-1900))
util.WriteUint32(mod, pTm+6*size, uint32(tm.Weekday()-time.Sunday))
util.WriteUint32(mod, pTm+7*size, uint32(tm.YearDay()-1))
util.WriteUint32(mod, pTm+8*size, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := util.View(mod, zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, mod api.Module, pVfs, nMicro uint32) _ErrorCode {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) _ErrorCode {
day := julianday.Float(time.Now())
util.WriteFloat64(mod, prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
util.WriteUint64(mod, piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) _ErrorCode {
rel := util.ReadString(mod, zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return _CANTOPEN_FULLPATH
}
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return _CANTOPEN_FULLPATH
}
mem := util.View(mod, zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
if fi, err := os.Lstat(abs); err == nil {
if fi.Mode()&fs.ModeSymlink != 0 {
return _OK_SYMLINK
}
return _OK
} else if errors.Is(err, fs.ErrNotExist) {
return _OK
}
return _CANTOPEN_FULLPATH
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode {
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
if err != nil {
return _IOERR_DELETE
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) _ErrorCode {
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
err := osAccess(path, flags)
var res uint32
var rc _ErrorCode
if flags == _ACCESS_EXISTS {
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
rc = _IOERR_ACCESS
}
} else {
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrPermission):
res = 0
default:
rc = _IOERR_ACCESS
}
}
util.WriteUint32(mod, pResOut, res)
return rc
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags _OpenFlag, pOutFlags uint32) _ErrorCode {
var oflags int
if flags&_OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&_OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&_OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&_OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var f *os.File
if zName == 0 {
f, err = os.CreateTemp("", "*.db")
} else {
name := util.ReadString(mod, zName, _MAX_PATHNAME)
f, err = osOpenFile(name, oflags, 0666)
}
if err != nil {
return _CANTOPEN
}
if flags&_OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
file := openVFSFile(ctx, mod, pFile, f)
file.psow = true
file.readOnly = flags&_OPEN_READONLY != 0
file.syncDir = runtime.GOOS != "windows" &&
flags&(_OPEN_CREATE) != 0 &&
flags&(_OPEN_MAIN_JOURNAL|_OPEN_SUPER_JOURNAL|_OPEN_WAL) != 0
if pOutFlags != 0 {
util.WriteUint32(mod, pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode {
err := closeVFSFile(ctx, mod, pFile)
if err != nil {
return _IOERR_CLOSE
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
buf := util.View(mod, zBuf, uint64(iAmt))
file := getVFSFile(ctx, mod, pFile)
n, err := file.ReadAt(buf, iOfst)
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return _IOERR_READ
}
for i := range buf[n:] {
buf[n+i] = 0
}
return _IOERR_SHORT_READ
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
buf := util.View(mod, zBuf, uint64(iAmt))
file := getVFSFile(ctx, mod, pFile)
_, err := file.WriteAt(buf, iOfst)
if err != nil {
return _IOERR_WRITE
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
err := file.Truncate(nByte)
if err != nil {
return _IOERR_TRUNCATE
}
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) _ErrorCode {
dataonly := (flags & _SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == _SYNC_FULL
file := getVFSFile(ctx, mod, pFile)
err := osSync(file.File, fullsync, dataonly)
if err != nil {
return _IOERR_FSYNC
}
if runtime.GOOS != "windows" && file.syncDir {
file.syncDir = false
f, err := os.Open(filepath.Dir(file.Name()))
if err != nil {
return _OK
}
defer f.Close()
err = osSync(f, false, false)
if err != nil {
return _IOERR_DIR_FSYNC
}
}
return _OK
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return _IOERR_SEEK
}
util.WriteUint64(mod, pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode {
switch op {
case _FCNTL_LOCKSTATE:
util.WriteUint32(mod, pArg, uint32(getVFSFile(ctx, mod, pFile).lock))
return _OK
case _FCNTL_LOCK_TIMEOUT:
file := getVFSFile(ctx, mod, pFile)
millis := file.lockTimeout.Milliseconds()
file.lockTimeout = time.Duration(util.ReadUint32(mod, pArg)) * time.Millisecond
util.WriteUint32(mod, pArg, uint32(millis))
return _OK
case _FCNTL_POWERSAFE_OVERWRITE:
file := getVFSFile(ctx, mod, pFile)
switch util.ReadUint32(mod, pArg) {
case 0:
file.psow = false
case 1:
file.psow = true
default:
if file.psow {
util.WriteUint32(mod, pArg, 1)
} else {
util.WriteUint32(mod, pArg, 0)
}
}
case _FCNTL_SIZE_HINT:
return vfsSizeHint(ctx, mod, pFile, pArg)
case _FCNTL_HAS_MOVED:
return vfsFileMoved(ctx, mod, pFile, pArg)
}
// Consider also implementing these opcodes (in use by SQLite):
// _FCNTL_BUSYHANDLER
// _FCNTL_COMMIT_PHASETWO
// _FCNTL_PDB
// _FCNTL_PRAGMA
// _FCNTL_SYNC
return _NOTFOUND
}
func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 {
return _DEFAULT_SECTOR_SIZE
}
func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) _DeviceCharacteristic {
file := getVFSFile(ctx, mod, pFile)
if file.psow {
return _IOCAP_POWERSAFE_OVERWRITE
}
return 0
}
func vfsSizeHint(ctx context.Context, mod api.Module, pFile, pArg uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
size := util.ReadUint64(mod, pArg)
err := osAllocate(file.File, int64(size))
if err != nil {
return _IOERR_TRUNCATE
}
return _OK
}
func vfsFileMoved(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
fi, err := file.Stat()
if err != nil {
return _IOERR_FSTAT
}
pi, err := os.Stat(file.Name())
if err != nil && !errors.Is(err, fs.ErrNotExist) {
return _IOERR_FSTAT
}
var res uint32
if !os.SameFile(fi, pi) {
res = 1
}
util.WriteUint32(mod, pResOut, res)
return _OK
}

54
internal/vfs/vfs_file.go Normal file
View File

@@ -0,0 +1,54 @@
package vfs
import (
"context"
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
type vfsFile struct {
*os.File
lockTimeout time.Duration
lock _LockLevel
psow bool
syncDir bool
readOnly bool
}
func newVFSFile(vfs *vfsState, file *os.File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f.File == nil {
vfs.files[id] = vfsFile{File: file}
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, vfsFile{File: file})
return uint32(len(vfs.files) - 1)
}
func getVFSFile(ctx context.Context, mod api.Module, pFile uint32) *vfsFile {
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+4)
return &vfs.files[id]
}
func openVFSFile(ctx context.Context, mod api.Module, pFile uint32, file *os.File) *vfsFile {
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := newVFSFile(vfs, file)
util.WriteUint32(mod, pFile+4, id)
return &vfs.files[id]
}
func closeVFSFile(ctx context.Context, mod api.Module, pFile uint32) error {
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+4)
file := vfs.files[id]
vfs.files[id] = vfsFile{}
return file.Close()
}

168
internal/vfs/vfs_lock.go Normal file
View File

@@ -0,0 +1,168 @@
package vfs
import (
"context"
"os"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
const (
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock _LockLevel) _ErrorCode {
// Argument check. SQLite never explicitly requests a pending lock.
if eLock != _LOCK_SHARED && eLock != _LOCK_RESERVED && eLock != _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
file := getVFSFile(ctx, mod, pFile)
switch {
case file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
case file.lock == _LOCK_NONE && eLock > _LOCK_SHARED:
// We never move from unlocked to anything higher than a shared lock.
panic(util.AssertErr())
case file.lock != _LOCK_SHARED && eLock == _LOCK_RESERVED:
// A shared lock is always held when a reserved lock is requested.
panic(util.AssertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if file.lock >= eLock {
return _OK
}
// Do not allow any kind of write-lock on a read-only database.
if file.readOnly && eLock >= _LOCK_RESERVED {
return _IOERR_LOCK
}
switch eLock {
case _LOCK_SHARED:
// Must be unlocked to get SHARED.
if file.lock != _LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = _LOCK_SHARED
return _OK
case _LOCK_RESERVED:
// Must be SHARED to get RESERVED.
if file.lock != _LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = _LOCK_RESERVED
return _OK
case _LOCK_EXCLUSIVE:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if file.lock <= _LOCK_NONE || file.lock >= _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if file.lock < _LOCK_PENDING {
if rc := osGetPendingLock(file.File); rc != _OK {
return rc
}
file.lock = _LOCK_PENDING
}
if rc := osGetExclusiveLock(file.File, file.lockTimeout); rc != _OK {
return rc
}
file.lock = _LOCK_EXCLUSIVE
return _OK
default:
panic(util.AssertErr())
}
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock _LockLevel) _ErrorCode {
// Argument check.
if eLock != _LOCK_NONE && eLock != _LOCK_SHARED {
panic(util.AssertErr())
}
file := getVFSFile(ctx, mod, pFile)
// Connection state check.
if file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// If we don't have a more restrictive lock, do nothing.
if file.lock <= eLock {
return _OK
}
switch eLock {
case _LOCK_SHARED:
if rc := osDowngradeLock(file.File, file.lock); rc != _OK {
return rc
}
file.lock = _LOCK_SHARED
return _OK
case _LOCK_NONE:
rc := osReleaseLock(file.File, file.lock)
file.lock = _LOCK_NONE
return rc
default:
panic(util.AssertErr())
}
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
file := getVFSFile(ctx, mod, pFile)
// Connection state check.
if file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
var locked bool
var rc _ErrorCode
if file.lock >= _LOCK_RESERVED {
locked = true
} else {
locked, rc = osCheckReservedLock(file.File)
}
var res uint32
if locked {
res = 1
}
util.WriteUint32(mod, pResOut, res)
return rc
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the RESERVED lock.
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
}
func osGetPendingLock(file *os.File) _ErrorCode {
// Acquire the PENDING lock.
return osWriteLock(file, _PENDING_BYTE, 1, 0)
}
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
// Test the RESERVED lock.
return osCheckLock(file, _RESERVED_BYTE, 1)
}

View File

@@ -0,0 +1,172 @@
package vfs
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mod := util.NewMockModule(128)
ctx, vfs := Context(context.TODO())
defer vfs.Close()
openVFSFile(ctx, mod, pFile1, file1)
openVFSFile(ctx, mod, pFile2, file2)
rc := vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_RESERVED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(ctx, mod, pFile2, _LOCK_EXCLUSIVE)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(ctx, mod, pFile1, _LOCK_SHARED)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsUnlock(ctx, mod, pFile2, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mod, pFile1, _LOCK_SHARED)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -0,0 +1,56 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
if start == 0 && len == 0 {
err := unix.Flock(int(file.Fd()), unix.LOCK_UN)
if err != nil {
return _IOERR_UNLOCK
}
}
return _OK
}
func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = unix.Flock(int(file.Fd()), how)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_SH|unix.LOCK_NB, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.LOCK_EX|unix.LOCK_NB, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,106 @@
//go:build !sqlite3_bsd
package vfs
import (
"io"
"os"
"time"
"golang.org/x/sys/unix"
)
const (
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
_F_OFD_SETLK = 90
_F_OFD_SETLKW = 91
_F_OFD_GETLK = 92
_F_OFD_SETLKWTIMEOUT = 93
)
type flocktimeout_t struct {
fl unix.Flock_t
timeout unix.Timespec
}
func osSync(file *os.File, fullsync, dataonly bool) error {
if fullsync {
return file.Sync()
}
return unix.Fsync(int(file.Fd()))
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
// https://stackoverflow.com/a/11497568/867786
store := unix.Fstore_t{
Flags: unix.F_ALLOCATECONTIG,
Posmode: unix.F_PEOFPOSMODE,
Offset: 0,
Length: size,
}
// Try to get a continous chunk of disk space.
err = unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
if err != nil {
// OK, perhaps we are too fragmented, allocate non-continuous.
store.Flags = unix.F_ALLOCATEALL
unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
}
return file.Truncate(size)
}
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := flocktimeout_t{fl: unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}}
var err error
if timeout == 0 {
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &lock.fl)
} else {
lock.timeout = unix.NsecToTimespec(int64(timeout / time.Nanosecond))
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLKWTIMEOUT, &lock.fl)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), _F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,25 @@
package vfs
import (
"os"
"golang.org/x/sys/unix"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
if dataonly {
_, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0)
if err != 0 {
return err
}
return nil
}
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
if size == 0 {
return nil
}
return unix.Fallocate(int(file.Fd()), 0, 0, size)
}

View File

@@ -0,0 +1,63 @@
//go:build linux || illumos
package vfs
import (
"os"
"time"
"golang.org/x/sys/unix"
)
func osUnlock(file *os.File, start, len int64) _ErrorCode {
err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{
Type: unix.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := unix.Flock_t{
Type: typ,
Start: start,
Len: len,
}
var err error
for {
err = unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &lock)
if errno, _ := err.(unix.Errno); errno != unix.EAGAIN {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode {
return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) {
lock := unix.Flock_t{
Type: unix.F_RDLCK,
Start: start,
Len: len,
}
if unix.FcntlFlock(file.Fd(), unix.F_OFD_GETLK, &lock) != nil {
return false, _IOERR_CHECKRESERVEDLOCK
}
return lock.Type != unix.F_UNLCK, _OK
}

View File

@@ -0,0 +1,23 @@
//go:build !linux && (!darwin || sqlite3_bsd)
package vfs
import (
"io"
"os"
)
func osSync(file *os.File, fullsync, dataonly bool) error {
return file.Sync()
}
func osAllocate(file *os.File, size int64) error {
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return err
}
if size <= off {
return nil
}
return file.Truncate(size)
}

View File

@@ -0,0 +1,87 @@
//go:build unix
package vfs
import (
"io/fs"
"os"
"time"
"golang.org/x/sys/unix"
)
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
return os.OpenFile(name, flag, perm)
}
func osAccess(path string, flags _AccessFlag) error {
var access uint32 // unix.F_OK
switch flags {
case _ACCESS_READWRITE:
access = unix.R_OK | unix.W_OK
case _ACCESS_READ:
access = unix.R_OK
}
return unix.Access(path, access)
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
return _ErrorCode(_BUSY)
}
// Acquire the SHARED lock.
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Acquire the EXCLUSIVE lock.
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
}
func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
if state >= _LOCK_EXCLUSIVE {
// Downgrade to a SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return _IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return osUnlock(file, _PENDING_BYTE, 2)
}
func osReleaseLock(file *os.File, _ _LockLevel) _ErrorCode {
// Release all locks.
return osUnlock(file, 0, 0)
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(unix.Errno); ok {
switch errno {
case
unix.EACCES,
unix.EAGAIN,
unix.EBUSY,
unix.EINTR,
unix.ENOLCK,
unix.EDEADLK,
unix.ETIMEDOUT:
return _ErrorCode(_BUSY)
case unix.EPERM:
return _ErrorCode(_PERM)
}
}
return def
}

View File

@@ -0,0 +1,240 @@
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/windows"
)
// osOpenFile is a simplified copy of [os.openFileNolog]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/os/file_windows.go
//
// See: https://go.dev/issue/32088
func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT}
}
r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e}
}
return os.NewFile(uintptr(r), name), nil
}
func osAccess(path string, flags _AccessFlag) error {
fi, err := os.Stat(path)
if err != nil {
return err
}
if flags == _ACCESS_EXISTS {
return nil
}
var want fs.FileMode = windows.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= windows.S_IWUSR
}
if fi.IsDir() {
want |= windows.S_IXUSR
}
if fi.Mode()&want != want {
return fs.ErrPermission
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
if rc == _OK {
// Acquire the SHARED lock.
rc = osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
// Release the PENDING lock.
osUnlock(file, _PENDING_BYTE, 1)
}
return rc
}
func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
if timeout == 0 {
timeout = time.Millisecond
}
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
if rc != _OK {
// Reacquire the SHARED lock.
osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0)
}
return rc
}
func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
if state >= _LOCK_EXCLUSIVE {
// Release the SHARED lock.
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return _IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= _LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= _LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osReleaseLock(file *os.File, state _LockLevel) _ErrorCode {
// Release all locks.
if state >= _LOCK_RESERVED {
osUnlock(file, _RESERVED_BYTE, 1)
}
if state >= _LOCK_SHARED {
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= _LOCK_PENDING {
osUnlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func osUnlock(file *os.File, start, len uint32) _ErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err == windows.ERROR_NOT_LOCKED {
return _OK
}
if err != nil {
return _IOERR_UNLOCK
}
return _OK
}
func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode {
var err error
for {
err = windows.LockFileEx(windows.Handle(file.Fd()), flags,
0, len, 0, &windows.Overlapped{Offset: start})
if errno, _ := err.(windows.Errno); errno != windows.ERROR_LOCK_VIOLATION {
break
}
if timeout < time.Millisecond {
break
}
timeout -= time.Millisecond
time.Sleep(time.Millisecond)
}
return osLockErrorCode(err, def)
}
func osReadLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, timeout, _IOERR_RDLOCK)
}
func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode {
return osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
start, len, timeout, _IOERR_LOCK)
}
func osCheckLock(file *os.File, start, len uint32) (bool, _ErrorCode) {
rc := osLock(file,
windows.LOCKFILE_FAIL_IMMEDIATELY,
start, len, 0, _IOERR_CHECKRESERVEDLOCK)
if rc == _BUSY {
return true, _OK
}
if rc == _OK {
osUnlock(file, start, len)
}
return false, rc
}
func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(windows.Errno); ok {
// https://devblogs.microsoft.com/oldnewthing/20140905-00/?p=63
switch errno {
case
windows.ERROR_LOCK_VIOLATION,
windows.ERROR_IO_PENDING,
windows.ERROR_OPERATION_ABORTED:
return _BUSY
}
}
return def
}
// syscallOpen is a simplified copy of [syscall.Open]
// that uses syscall.FILE_SHARE_DELETE.
// https://go.dev/src/syscall/syscall_windows.go
func syscallOpen(path string, mode int, perm uint32) (fd syscall.Handle, err error) {
if len(path) == 0 {
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
}
pathp, err := syscall.UTF16PtrFromString(path)
if err != nil {
return syscall.InvalidHandle, err
}
var access uint32
switch mode & (syscall.O_RDONLY | syscall.O_WRONLY | syscall.O_RDWR) {
case syscall.O_RDONLY:
access = syscall.GENERIC_READ
case syscall.O_WRONLY:
access = syscall.GENERIC_WRITE
case syscall.O_RDWR:
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
}
if mode&syscall.O_CREAT != 0 {
access |= syscall.GENERIC_WRITE
}
if mode&syscall.O_APPEND != 0 {
access &^= syscall.GENERIC_WRITE
access |= syscall.FILE_APPEND_DATA
}
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
var createmode uint32
switch {
case mode&(syscall.O_CREAT|syscall.O_EXCL) == (syscall.O_CREAT | syscall.O_EXCL):
createmode = syscall.CREATE_NEW
case mode&(syscall.O_CREAT|syscall.O_TRUNC) == (syscall.O_CREAT | syscall.O_TRUNC):
createmode = syscall.CREATE_ALWAYS
case mode&syscall.O_CREAT == syscall.O_CREAT:
createmode = syscall.OPEN_ALWAYS
case mode&syscall.O_TRUNC == syscall.O_TRUNC:
createmode = syscall.TRUNCATE_EXISTING
default:
createmode = syscall.OPEN_EXISTING
}
var attrs uint32 = syscall.FILE_ATTRIBUTE_NORMAL
if perm&syscall.S_IWRITE == 0 {
attrs = syscall.FILE_ATTRIBUTE_READONLY
}
if createmode == syscall.OPEN_EXISTING && access == syscall.GENERIC_READ {
// Necessary for opening directory handles.
attrs |= syscall.FILE_FLAG_BACKUP_SEMANTICS
}
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, attrs, 0)
}

View File

@@ -1,4 +1,4 @@
package sqlite3
package vfs
import (
"bytes"
@@ -11,72 +11,67 @@ import (
"testing"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
func Test_vfsExit(t *testing.T) {
mem := newMemory(128)
ctx := context.TODO()
defer func() { _ = recover() }()
vfsExit(ctx, mem.mod, 1)
t.Error("want panic")
}
func Test_vfsLocaltime(t *testing.T) {
mem := newMemory(128)
mod := util.NewMockModule(128)
ctx := context.TODO()
rc := vfsLocaltime(ctx, mem.mod, 0, 4)
tm := time.Now()
rc := vfsLocaltime(ctx, mod, 4, tm.Unix())
if rc != 0 {
t.Fatal("returned", rc)
}
epoch := time.Unix(0, 0)
if s := mem.readUint32(4 + 0*4); int(s) != epoch.Second() {
if s := util.ReadUint32(mod, 4+0*4); int(s) != tm.Second() {
t.Error("wrong second")
}
if m := mem.readUint32(4 + 1*4); int(m) != epoch.Minute() {
if m := util.ReadUint32(mod, 4+1*4); int(m) != tm.Minute() {
t.Error("wrong minute")
}
if h := mem.readUint32(4 + 2*4); int(h) != epoch.Hour() {
if h := util.ReadUint32(mod, 4+2*4); int(h) != tm.Hour() {
t.Error("wrong hour")
}
if d := mem.readUint32(4 + 3*4); int(d) != epoch.Day() {
if d := util.ReadUint32(mod, 4+3*4); int(d) != tm.Day() {
t.Error("wrong day")
}
if m := mem.readUint32(4 + 4*4); time.Month(1+m) != epoch.Month() {
if m := util.ReadUint32(mod, 4+4*4); time.Month(1+m) != tm.Month() {
t.Error("wrong month")
}
if y := mem.readUint32(4 + 5*4); 1900+int(y) != epoch.Year() {
if y := util.ReadUint32(mod, 4+5*4); 1900+int(y) != tm.Year() {
t.Error("wrong year")
}
if w := mem.readUint32(4 + 6*4); time.Weekday(w) != epoch.Weekday() {
if w := util.ReadUint32(mod, 4+6*4); time.Weekday(w) != tm.Weekday() {
t.Error("wrong weekday")
}
if d := mem.readUint32(4 + 7*4); int(d) != epoch.YearDay()-1 {
if d := util.ReadUint32(mod, 4+7*4); int(d) != tm.YearDay()-1 {
t.Error("wrong yearday")
}
}
func Test_vfsRandomness(t *testing.T) {
mem := newMemory(128)
mod := util.NewMockModule(128)
ctx := context.TODO()
rc := vfsRandomness(context.TODO(), mem.mod, 0, 16, 4)
rc := vfsRandomness(ctx, mod, 0, 16, 4)
if rc != 16 {
t.Fatal("returned", rc)
}
var zero [16]byte
if got := mem.view(4, 16); bytes.Equal(got, zero[:]) {
if got := util.View(mod, 4, 16); bytes.Equal(got, zero[:]) {
t.Fatal("all zero")
}
}
func Test_vfsSleep(t *testing.T) {
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
rc := vfsSleep(ctx, 0, 123456)
rc := vfsSleep(ctx, mod, 0, 123456)
if rc != 0 {
t.Fatal("returned", rc)
}
@@ -88,56 +83,56 @@ func Test_vfsSleep(t *testing.T) {
}
func Test_vfsCurrentTime(t *testing.T) {
mem := newMemory(128)
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
rc := vfsCurrentTime(ctx, mem.mod, 0, 4)
rc := vfsCurrentTime(ctx, mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
want := julianday.Float(now)
if got := mem.readFloat64(4); float32(got) != float32(want) {
if got := util.ReadFloat64(mod, 4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsCurrentTime64(t *testing.T) {
mem := newMemory(128)
mod := util.NewMockModule(128)
ctx := context.TODO()
now := time.Now()
time.Sleep(time.Millisecond)
rc := vfsCurrentTime64(ctx, mem.mod, 0, 4)
rc := vfsCurrentTime64(ctx, mod, 0, 4)
if rc != 0 {
t.Fatal("returned", rc)
}
day, nsec := julianday.Date(now)
want := day*86_400_000 + nsec/1_000_000
if got := mem.readUint64(4); float32(got) != float32(want) {
if got := util.ReadUint64(mod, 4); float32(got) != float32(want) {
t.Errorf("got %v, want %v", got, want)
}
}
func Test_vfsFullPathname(t *testing.T) {
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, ".")
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 4, ".")
ctx := context.TODO()
rc := vfsFullPathname(ctx, mem.mod, 0, 4, 0, 8)
if rc != uint32(CANTOPEN_FULLPATH) {
t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH)
rc := vfsFullPathname(ctx, mod, 0, 4, 0, 8)
if rc != _CANTOPEN_FULLPATH {
t.Errorf("returned %d, want %d", rc, _CANTOPEN_FULLPATH)
}
rc = vfsFullPathname(ctx, mem.mod, 0, 4, _MAX_PATHNAME, 8)
rc = vfsFullPathname(ctx, mod, 0, 4, _MAX_PATHNAME, 8)
if rc != _OK {
t.Fatal("returned", rc)
}
want, _ := filepath.Abs(".")
if got := mem.readString(8, _MAX_PATHNAME); got != want {
if got := util.ReadString(mod, 8, _MAX_PATHNAME); got != want {
t.Errorf("got %v, want %v", got, want)
}
}
@@ -151,11 +146,11 @@ func Test_vfsDelete(t *testing.T) {
}
file.Close()
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(4, name)
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 4, name)
ctx := context.TODO()
rc := vfsDelete(ctx, mem.mod, 0, 4, 1)
rc := vfsDelete(ctx, mod, 0, 4, 1)
if rc != _OK {
t.Fatal("returned", rc)
}
@@ -164,8 +159,8 @@ func Test_vfsDelete(t *testing.T) {
t.Fatal("did not delete the file")
}
rc = vfsDelete(ctx, mem.mod, 0, 4, 1)
if rc != _OK {
rc = vfsDelete(ctx, mod, 0, 4, 1)
if rc != _IOERR_DELETE_NOENT {
t.Fatal("returned", rc)
}
}
@@ -182,99 +177,99 @@ func Test_vfsAccess(t *testing.T) {
t.Fatal(err)
}
mem := newMemory(128 + _MAX_PATHNAME)
mem.writeString(8, dir)
mod := util.NewMockModule(128 + _MAX_PATHNAME)
util.WriteString(mod, 8, dir)
ctx := context.TODO()
rc := vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_EXISTS, 4)
rc := vfsAccess(ctx, mod, 0, 8, _ACCESS_EXISTS, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("directory did not exist")
}
rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4)
rc = vfsAccess(ctx, mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 1 {
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("can't access directory")
}
mem.writeString(8, file)
rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4)
util.WriteString(mod, 8, file)
rc = vfsAccess(ctx, mod, 0, 8, _ACCESS_READWRITE, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(4); got != 0 {
if got := util.ReadUint32(mod, 4); got != 0 {
t.Error("can access file")
}
}
func Test_vfsFile(t *testing.T) {
mem := newMemory(128)
ctx, vfs := vfsContext(context.TODO())
mod := util.NewMockModule(128)
ctx, vfs := Context(context.TODO())
defer vfs.Close()
// Open a temporary file.
rc := vfsOpen(ctx, mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0)
rc := vfsOpen(ctx, mod, 0, 0, 4, _OPEN_CREATE|_OPEN_EXCLUSIVE|_OPEN_READWRITE|_OPEN_DELETEONCLOSE, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Write stuff.
text := "Hello world!"
mem.writeString(16, text)
rc = vfsWrite(ctx, mem.mod, 4, 16, uint32(len(text)), 0)
util.WriteString(mod, 16, text)
rc = vfsWrite(ctx, mod, 4, 16, uint32(len(text)), 0)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(ctx, mem.mod, 4, 16)
rc = vfsFileSize(ctx, mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != uint32(len(text)) {
if got := util.ReadUint32(mod, 16); got != uint32(len(text)) {
t.Errorf("got %d", got)
}
// Partial read at offset.
rc = vfsRead(ctx, mem.mod, 4, 16, uint32(len(text)), 4)
if rc != uint32(IOERR_SHORT_READ) {
rc = vfsRead(ctx, mod, 4, 16, uint32(len(text)), 4)
if rc != _IOERR_SHORT_READ {
t.Fatal("returned", rc)
}
if got := mem.readString(16, 64); got != text[4:] {
if got := util.ReadString(mod, 16, 64); got != text[4:] {
t.Errorf("got %q", got)
}
// Truncate the file.
rc = vfsTruncate(ctx, mem.mod, 4, 4)
rc = vfsTruncate(ctx, mod, 4, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
// Check file size.
rc = vfsFileSize(ctx, mem.mod, 4, 16)
rc = vfsFileSize(ctx, mod, 4, 16)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(16); got != 4 {
if got := util.ReadUint32(mod, 16); got != 4 {
t.Errorf("got %d", got)
}
// Read at offset.
rc = vfsRead(ctx, mem.mod, 4, 32, 4, 0)
rc = vfsRead(ctx, mod, 4, 32, 4, 0)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readString(32, 64); got != text[:4] {
if got := util.ReadString(mod, 32, 64); got != text[:4] {
t.Errorf("got %q", got)
}
// Close the file.
rc = vfsClose(ctx, mem.mod, 4)
rc = vfsClose(ctx, mod, 4)
if rc != _OK {
t.Fatal("returned", rc)
}

114
mem.go
View File

@@ -1,114 +0,0 @@
package sqlite3
import (
"bytes"
"math"
"github.com/tetratelabs/wazero/api"
)
type memory struct {
mod api.Module
}
func (m memory) view(ptr uint32, size uint64) []byte {
if ptr == 0 {
panic(nilErr)
}
if size > math.MaxUint32 {
panic(rangeErr)
}
buf, ok := m.mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(rangeErr)
}
return buf
}
func (m memory) readUint32(ptr uint32) uint32 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint32Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint32(ptr, v uint32) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint32Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readUint64(ptr uint32) uint64 {
if ptr == 0 {
panic(nilErr)
}
v, ok := m.mod.Memory().ReadUint64Le(ptr)
if !ok {
panic(rangeErr)
}
return v
}
func (m memory) writeUint64(ptr uint32, v uint64) {
if ptr == 0 {
panic(nilErr)
}
ok := m.mod.Memory().WriteUint64Le(ptr, v)
if !ok {
panic(rangeErr)
}
}
func (m memory) readFloat64(ptr uint32) float64 {
return math.Float64frombits(m.readUint64(ptr))
}
func (m memory) writeFloat64(ptr uint32, v float64) {
m.writeUint64(ptr, math.Float64bits(v))
}
func (m memory) readString(ptr, maxlen uint32) string {
if ptr == 0 {
panic(nilErr)
}
switch maxlen {
case 0:
return ""
case math.MaxUint32:
// avoid overflow
default:
maxlen = maxlen + 1
}
mem := m.mod.Memory()
buf, ok := mem.Read(ptr, maxlen)
if !ok {
buf, ok = mem.Read(ptr, mem.Size()-ptr)
if !ok {
panic(rangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(noNulErr)
} else {
return string(buf[:i])
}
}
func (m memory) writeBytes(ptr uint32, b []byte) {
buf := m.view(ptr, uint64(len(b)))
copy(buf, b)
}
func (m memory) writeString(ptr uint32, s string) {
buf := m.view(ptr, uint64(len(s)+1))
buf[len(s)] = 0
copy(buf, s)
}

View File

@@ -1,90 +0,0 @@
package sqlite3
import (
"math"
"testing"
)
func Test_memory_view_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(0, 8)
t.Error("want panic")
}
func Test_memory_view_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(126, 8)
t.Error("want panic")
}
func Test_memory_view_overflow(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.view(1, math.MaxInt64)
t.Error("want panic")
}
func Test_memory_readUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(0)
t.Error("want panic")
}
func Test_memory_readUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint32(126)
t.Error("want panic")
}
func Test_memory_readUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(0)
t.Error("want panic")
}
func Test_memory_readUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readUint64(126)
t.Error("want panic")
}
func Test_memory_writeUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint32(126, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(0, 1)
t.Error("want panic")
}
func Test_memory_writeUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.writeUint64(126, 1)
t.Error("want panic")
}
func Test_memory_readString_range(t *testing.T) {
defer func() { _ = recover() }()
mem := newMemory(128)
mem.readString(130, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -3,15 +3,13 @@ package sqlite3
import (
"context"
"crypto/rand"
"io"
"math"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/internal/vfs"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
@@ -28,11 +26,10 @@ var (
)
var sqlite3 struct {
once sync.Once
runtime wazero.Runtime
compiled wazero.CompiledModule
instances atomic.Uint64
err error
runtime wazero.Runtime
compiled wazero.CompiledModule
err error
once sync.Once
}
func instantiateModule() (*module, error) {
@@ -43,12 +40,7 @@ func instantiateModule() (*module, error) {
return nil, sqlite3.err
}
name := "sqlite3-" + strconv.FormatUint(sqlite3.instances.Add(1), 10)
cfg := wazero.NewModuleConfig().WithName(name).
WithSysWalltime().WithSysNanotime().WithSysNanosleep().
WithOsyield(runtime.Gosched).
WithRandSource(rand.Reader)
cfg := wazero.NewModuleConfig()
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
if err != nil {
@@ -60,7 +52,12 @@ func instantiateModule() (*module, error) {
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
vfsInstantiate(ctx, sqlite3.runtime)
env := vfs.Export(sqlite3.runtime.NewHostModuleBuilder("env"))
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
return
}
bin := Binary
if bin == nil && Path != "" {
@@ -70,7 +67,7 @@ func compileModule() {
}
}
if bin == nil {
sqlite3.err = binaryErr
sqlite3.err = util.BinaryErr
return
}
@@ -79,38 +76,39 @@ func compileModule() {
type module struct {
ctx context.Context
mem memory
api sqliteAPI
mod api.Module
vfs io.Closer
api sqliteAPI
arg [8]uint64
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m.mem = memory{mod}
m.ctx, m.vfs = vfsContext(context.Background())
m.mod = mod
m.ctx, m.vfs = vfs.Context(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
if f == nil {
err = noFuncErr + errorString(name)
err = util.NoFuncErr + util.ErrorString(name)
return nil
}
return f
}
getVal := func(name string) uint32 {
global := mod.ExportedGlobal(name)
if global == nil {
err = noGlobalErr + errorString(name)
g := mod.ExportedGlobal(name)
if g == nil {
err = util.NoGlobalErr + util.ErrorString(name)
return 0
}
return m.mem.readUint32(uint32(global.Get()))
return util.ReadUint32(mod, uint32(g.Get()))
}
m.api = sqliteAPI{
free: getFun("free"),
malloc: getFun("malloc"),
destructor: uint64(getVal("malloc_destructor")),
destructor: getVal("malloc_destructor"),
errcode: getFun("sqlite3_errcode"),
errstr: getFun("sqlite3_errstr"),
errmsg: getFun("sqlite3_errmsg"),
@@ -164,7 +162,7 @@ func newModule(mod api.Module) (m *module, err error) {
}
func (m *module) close() error {
err := m.mem.mod.Close(m.ctx)
err := m.mod.Close(m.ctx)
m.vfs.Close()
return err
}
@@ -177,19 +175,19 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(oomErr)
panic(util.OOMErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = m.mem.readString(uint32(r[0]), _MAX_STRING)
err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING)
err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
}
if sql != nil {
@@ -207,13 +205,14 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
r, err := fn.Call(m.ctx, params...)
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
if err != nil {
// The module closed or panicked; release resources.
m.vfs.Close()
panic(err)
}
return r
return m.arg[:]
}
func (m *module) free(ptr uint32) {
@@ -225,12 +224,12 @@ func (m *module) free(ptr uint32) {
func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(oomErr)
panic(util.OOMErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
if ptr == 0 && size != 0 {
panic(oomErr)
panic(util.OOMErr)
}
return ptr
}
@@ -240,13 +239,13 @@ func (m *module) newBytes(b []byte) uint32 {
return 0
}
ptr := m.new(uint64(len(b)))
m.mem.writeBytes(ptr, b)
util.WriteBytes(m.mod, ptr, b)
return ptr
}
func (m *module) newString(s string) uint32 {
ptr := m.new(uint64(len(s) + 1))
m.mem.writeString(ptr, s)
util.WriteString(m.mod, ptr, s)
return ptr
}
@@ -260,10 +259,10 @@ func (m *module) newArena(size uint64) arena {
type arena struct {
m *module
ptrs []uint32
base uint32
next uint32
size uint32
ptrs []uint32
}
func (a *arena) free() {
@@ -294,16 +293,24 @@ func (a *arena) new(size uint64) uint32 {
return ptr
}
func (a *arena) bytes(b []byte) uint32 {
if b == nil {
return 0
}
ptr := a.new(uint64(len(b)))
util.WriteBytes(a.m.mod, ptr, b)
return ptr
}
func (a *arena) string(s string) uint32 {
ptr := a.new(uint64(len(s) + 1))
a.m.mem.writeString(ptr, s)
util.WriteString(a.m.mod, ptr, s)
return ptr
}
type sqliteAPI struct {
free api.Function
malloc api.Function
destructor uint64
errcode api.Function
errstr api.Function
errmsg api.Function
@@ -348,5 +355,6 @@ type sqliteAPI struct {
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
destructor uint32
interrupt uint32
}

View File

@@ -4,8 +4,14 @@ import (
"bytes"
"math"
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
)
func init() {
Path = "./embed/sqlite3.wasm"
}
func TestConn_error_OOM(t *testing.T) {
t.Parallel()
@@ -20,14 +26,14 @@ func TestConn_error_OOM(t *testing.T) {
t.Error("want panic")
}
func TestConn_call_nil(t *testing.T) {
func TestConn_call_closed(t *testing.T) {
t.Parallel()
m, err := instantiateModule()
if err != nil {
t.Fatal(err)
}
defer m.close()
m.close()
defer func() { _ = recover() }()
m.call(m.api.free)
@@ -43,14 +49,18 @@ func TestConn_new(t *testing.T) {
}
defer m.close()
testOOM := func(size uint64) {
t.Run("MaxUint32", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(size)
m.new(math.MaxUint32)
t.Error("want panic")
}
})
testOOM(math.MaxUint32)
testOOM(_MAX_ALLOCATION_SIZE)
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
defer func() { _ = recover() }()
m.new(_MAX_ALLOCATION_SIZE)
m.new(_MAX_ALLOCATION_SIZE)
t.Error("want panic")
})
}
func TestConn_newArena(t *testing.T) {
@@ -71,7 +81,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := m.mem.readString(ptr, math.MaxUint32); got != title {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
t.Errorf("got %q, want %q", got, title)
}
@@ -80,7 +90,7 @@ func TestConn_newArena(t *testing.T) {
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := m.mem.readString(ptr, math.MaxUint32); got != body {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
arena.free()
@@ -107,7 +117,7 @@ func TestConn_newBytes(t *testing.T) {
}
want := buf
if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) {
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -133,7 +143,7 @@ func TestConn_newString(t *testing.T) {
}
want := str + "\000"
if got := m.mem.view(ptr, uint64(len(want))); string(got) != want {
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
@@ -159,22 +169,22 @@ func TestConn_getString(t *testing.T) {
}
want := "sqlite3"
if got := m.mem.readString(ptr, math.MaxUint32); got != want {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
t.Errorf("got %q, want %q", got, want)
}
if got := m.mem.readString(ptr, 0); got != "" {
if got := util.ReadString(m.mod, ptr, 0); got != "" {
t.Errorf("got %q, want empty", got)
}
func() {
defer func() { _ = recover() }()
m.mem.readString(ptr, uint32(len(want)/2))
util.ReadString(m.mod, ptr, uint32(len(want)/2))
t.Error("want panic")
}()
func() {
defer func() { _ = recover() }()
m.mem.readString(0, math.MaxUint32)
util.ReadString(m.mod, 0, math.MaxUint32)
t.Error("want panic")
}()
}

View File

@@ -1,9 +0,0 @@
#include <stddef.h>
#include "main.c"
#include "os.c"
#include "qsort.c"
#include "sqlite3.c"
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);

View File

@@ -1,13 +1,31 @@
#!/usr/bin/env bash
set -eo pipefail
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "sqlite3.c" ]; then
url="https://sqlite.org/2023/sqlite-amalgamation-3410000.zip"
curl "$url" > sqlite.zip
unzip -d . sqlite.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
rm sqlite.zip
fi
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3410200.zip"
unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/uint.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/uuid.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/base64.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/regexp.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/series.c"
cd ~-
cd ../internal/vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/config02.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/crash01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/crash02.subtest"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/multiwrite01.test"
cd ~-
cd ../internal/vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/test/speedtest1.c"
cd ~-

1
sqlite3/ext/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.c

View File

@@ -1,8 +1,7 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
clang-format -i \
main.c \
os.c \
qsort.c \
amalg.c
shopt -s extglob
clang-format -i !(sqlite3*).@(c|h)

View File

@@ -1,14 +1,32 @@
#include <stdbool.h>
#include <stddef.h>
#include "sqlite3.h"
#include "sqlite3.c"
//
#include "os.c"
//
#include "ext/base64.c"
#include "ext/decimal.c"
#include "ext/regexp.c"
#include "ext/series.c"
#include "ext/uint.c"
#include "ext/uuid.c"
#include "time.c"
int main() {
int rc = sqlite3_initialize();
if (rc != SQLITE_OK) return 1;
}
sqlite3_vfs *os_vfs();
sqlite3_destructor_type malloc_destructor = &free;
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
int sqlite3_os_init() {
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
}
__attribute__((constructor)) void init() {
sqlite3_initialize();
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
sqlite3_auto_extension((void (*)(void))sqlite3_decimal_init);
sqlite3_auto_extension((void (*)(void))sqlite3_regexp_init);
sqlite3_auto_extension((void (*)(void))sqlite3_series_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uint_init);
sqlite3_auto_extension((void (*)(void))sqlite3_uuid_init);
sqlite3_auto_extension((void (*)(void))sqlite3_time_init);
}

View File

@@ -2,7 +2,7 @@
#include "sqlite3.h"
int os_localtime(sqlite3_int64, struct tm *);
int os_localtime(struct tm *, sqlite3_int64);
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
int os_sleep(sqlite3_vfs *, int microseconds);
@@ -17,37 +17,32 @@ int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
struct os_file {
sqlite3_file base;
int id;
int lock;
int handle;
};
static_assert(offsetof(struct os_file, handle) == 4, "Unexpected offset");
int os_close(sqlite3_file *);
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
int os_truncate(sqlite3_file *, sqlite3_int64 size);
int os_sync(sqlite3_file *, int flags);
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
int os_file_control(sqlite3_file *pFile, int op, void *pArg);
int os_file_control(sqlite3_file *, int op, void *pArg);
int os_sector_size(sqlite3_file *file);
int os_device_characteristics(sqlite3_file *file);
int os_lock(sqlite3_file *pFile, int eLock);
int os_unlock(sqlite3_file *pFile, int eLock);
int os_check_reserved_lock(sqlite3_file *pFile, int *pResOut);
int os_lock(sqlite3_file *, int eLock);
int os_unlock(sqlite3_file *, int eLock);
int os_check_reserved_lock(sqlite3_file *, int *pResOut);
static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; }
static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) {
*pResOut = 0;
return SQLITE_OK;
}
static int no_file_control(sqlite3_file *pFile, int op, void *pArg) {
return SQLITE_NOTFOUND;
}
static int no_sector_size(sqlite3_file *pFile) { return 0; }
static int no_device_characteristics(sqlite3_file *pFile) { return 0; }
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime((sqlite3_int64)*pTime, pTm);
static int os_file_control_w(sqlite3_file *file, int op, void *pArg) {
struct os_file *pFile = (struct os_file *)file;
if (op == SQLITE_FCNTL_VFSNAME) {
*(char **)pArg = sqlite3_mprintf("%s", "os");
return SQLITE_OK;
}
return os_file_control(file, op, pArg);
}
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
@@ -63,12 +58,18 @@ static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
.xLock = os_lock,
.xUnlock = os_unlock,
.xCheckReservedLock = os_check_reserved_lock,
.xFileControl = no_file_control,
.xDeviceCharacteristics = no_device_characteristics,
.xFileControl = os_file_control_w,
.xSectorSize = os_sector_size,
.xDeviceCharacteristics = os_device_characteristics,
};
memset(file, 0, sizeof(struct os_file));
int rc = os_open(vfs, zName, file, flags, pOutFlags);
file->pMethods = (char)rc == SQLITE_OK ? &os_io : NULL;
return rc;
if (rc) {
return rc;
}
file->pMethods = &os_io;
return SQLITE_OK;
}
sqlite3_vfs *os_vfs() {
@@ -90,3 +91,7 @@ sqlite3_vfs *os_vfs() {
};
return &os_vfs;
}
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return os_localtime(pTm, (sqlite3_int64)*pTime);
}

View File

@@ -1,14 +0,0 @@
#include <stddef.h>
void qsort_r(void *, size_t, size_t,
int (*)(const void *, const void *, void *), void *);
typedef int (*cmpfun)(const void *, const void *);
static int wrapper_cmp(const void *v1, const void *v2, void *cmp) {
return ((cmpfun)cmp)(v1, v2);
}
void qsort(void *base, size_t nel, size_t width, cmpfun cmp) {
qsort_r(base, nel, width, wrapper_cmp, cmp);
}

View File

@@ -28,8 +28,11 @@
#define SQLITE_OMIT_AUTOINIT
#define SQLITE_USE_ALLOCA
// Other Options
// #define SQLITE_ALLOW_URI_AUTHORITY
// Because WASM does not support shared memory,
// SQLite disables it for WASM builds.
// SQLite disables WAL for WASM builds.
// We set the default locking mode to EXCLUSIVE instead.
// https://www.sqlite.org/wal.html#noshm
#undef SQLITE_OMIT_WAL
@@ -48,15 +51,9 @@
#define SQLITE_ENABLE_RTREE 1
#define SQLITE_ENABLE_GEOPOLY 1
// Snapshot
// #define SQLITE_ENABLE_SNAPSHOT 1
// Session Extension
// #define SQLITE_ENABLE_SESSION 1
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
// Resumable Bulk Update Extension
// #define SQLITE_ENABLE_RBU 1
// Implemented in Go.
int localtime_s(struct tm *const pTm, time_t const *const pTime);

32
sqlite3/time.c Normal file
View File

@@ -0,0 +1,32 @@
#include <string.h>
#include "sqlite3.h"
static int time_collation(void *pArg, int nKey1, const void *pKey1, int nKey2,
const void *pKey2) {
// Remove a Z suffix if one key is no longer than the other.
// A Z suffix collates before any character but after the empty string.
// This avoids making different keys equal.
const int nK1 = nKey1;
const int nK2 = nKey2;
const char *pK1 = (const char *)pKey1;
const char *pK2 = (const char *)pKey2;
if (nK1 && nK1 <= nK2 && pK1[nK1 - 1] == 'Z') {
nKey1--;
}
if (nK2 && nK2 <= nK1 && pK2[nK2 - 1] == 'Z') {
nKey2--;
}
int n = nKey1 < nKey2 ? nKey1 : nKey2;
int rc = memcmp(pKey1, pKey2, n);
if (rc == 0) {
rc = nKey1 - nKey2;
}
return rc;
}
int sqlite3_time_init(sqlite3 *db, char **pzErrMsg,
const sqlite3_api_routines *pApi) {
return sqlite3_create_collation(db, "time", SQLITE_UTF8, 0, time_collation);
}

89
stmt.go
View File

@@ -3,6 +3,8 @@ package sqlite3
import (
"math"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
)
// Stmt is a prepared statement object.
@@ -10,8 +12,8 @@ import (
// https://www.sqlite.org/c3ref/stmt.html
type Stmt struct {
c *Conn
handle uint32
err error
handle uint32
}
// Close destroys the prepared statement object.
@@ -119,7 +121,7 @@ func (s *Stmt) BindName(param int) string {
if ptr == 0 {
return ""
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// BindBool binds a bool to the prepared statement.
@@ -172,7 +174,7 @@ func (s *Stmt) BindText(param int, value string) error {
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor, _UTF8)
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
}
@@ -186,7 +188,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
r := s.c.call(s.c.api.bindBlob,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
s.c.api.destructor)
uint64(s.c.api.destructor))
return s.c.error(r[0])
}
@@ -215,6 +217,9 @@ func (s *Stmt) BindNull(param int) error {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
if format == TimeFormatDefault {
return s.bindRFC3339Nano(param, value)
}
switch v := format.Encode(value).(type) {
case string:
s.BindText(param, v)
@@ -223,11 +228,25 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error {
case float64:
s.BindFloat(param, v)
default:
panic(assertErr())
panic(util.AssertErr())
}
return nil
}
func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
const maxlen = uint64(len(time.RFC3339Nano))
ptr := s.c.new(maxlen)
buf := util.View(s.c.mod, ptr, maxlen)
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -247,9 +266,9 @@ func (s *Stmt) ColumnName(col int) string {
ptr := uint32(r[0])
if ptr == 0 {
panic(oomErr)
panic(util.OOMErr)
}
return s.c.mem.readString(ptr, _MAX_STRING)
return util.ReadString(s.c.mod, ptr, _MAX_STRING)
}
// ColumnType returns the initial [Datatype] of the result column.
@@ -320,7 +339,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
case NULL:
return time.Time{}
default:
panic(assertErr())
panic(util.AssertErr())
}
t, err := format.Decode(v)
if err != nil {
@@ -334,21 +353,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnText(col int) string {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return ""
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
mem := s.c.mem.view(ptr, r[0])
return string(mem)
return string(s.ColumnRawText(col))
}
// ColumnBlob appends to buf and returns
@@ -357,6 +362,39 @@ func (s *Stmt) ColumnText(col int) string {
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
return append(buf, s.ColumnRawBlob(col)...)
}
// ColumnRawText returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r[0])
}
// ColumnRawBlob returns the value of the result column as a []byte.
// The []byte is owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
@@ -364,14 +402,13 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
return buf[0:0]
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
mem := s.c.mem.view(ptr, r[0])
return append(buf[0:0], mem...)
return util.View(s.c.mod, ptr, r[0])
}
// Return true if stmt is an empty SQL statement.

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"errors"
"fmt"
"hash/adler32"
"io"
"testing"
@@ -48,17 +49,17 @@ func TestBlob(t *testing.T) {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:size/2]))
_, err = blob.Write(data[:size/2])
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(blob, bytes.NewReader(data[:]))
if !errors.Is(err, sqlite3.ERROR) {
t.Fatal("want error")
n, err := blob.Write(data[:])
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
_, err = io.Copy(blob, bytes.NewReader(data[size/2:size]))
_, err = blob.Write(data[size/2 : size])
if err != nil {
t.Fatal(err)
}
@@ -87,6 +88,126 @@ func TestBlob(t *testing.T) {
}
}
func TestBlob_large(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1000000))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
size := blob.Size()
if size != 1000000 {
t.Errorf("got %d, want 1000000", size)
}
hash := adler32.New()
_, err = io.CopyN(blob, io.TeeReader(rand.Reader, hash), 1000000)
if err != nil {
t.Fatal(err)
}
_, err = blob.Seek(0, io.SeekStart)
if err != nil {
t.Fatal(err)
}
want := hash.Sum32()
hash.Reset()
_, err = io.Copy(hash, blob)
if err != nil {
t.Fatal(err)
}
if got := hash.Sum32(); got != want {
t.Fatalf("got %d, want %d", got, want)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_overflow(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO test VALUES (zeroblob(1024))`)
if err != nil {
t.Fatal(err)
}
blob, err := db.OpenBlob("main", "test", "col", db.LastInsertRowID(), true)
if err != nil {
t.Fatal(err)
}
defer blob.Close()
n, err := blob.ReadFrom(rand.Reader)
if n != 1024 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
n, err = blob.ReadFrom(rand.Reader)
if n != 0 || !errors.Is(err, sqlite3.ERROR) {
t.Fatalf("got (%d, %v), want (0, ERROR)", n, err)
}
_, err = blob.Seek(-128, io.SeekEnd)
if err != nil {
t.Fatal(err)
}
n, err = blob.WriteTo(io.Discard)
if n != 128 || err != nil {
t.Fatalf("got (%d, %v), want (128, nil)", n, err)
}
n, err = blob.WriteTo(io.Discard)
if n != 0 || err != nil {
t.Fatalf("got (%d, %v), want (0, nil)", n, err)
}
if err := blob.Close(); err != nil {
t.Fatal(err)
}
if err := db.Close(); err != nil {
t.Fatal(err)
}
}
func TestBlob_invalid(t *testing.T) {
t.Parallel()

View File

@@ -126,7 +126,7 @@ func testTxQuery(t params) {
if r.Err() != nil {
t.Fatal(err)
}
t.Fatal("expected one rows")
t.Fatal("expected one row")
}
var name string

View File

@@ -13,7 +13,7 @@ import (
func TestConn_Open_dir(t *testing.T) {
t.Parallel()
_, err := sqlite3.Open(".")
_, err := sqlite3.OpenFlags(".", 0)
if err == nil {
t.Fatal("want error")
}
@@ -202,59 +202,3 @@ func TestConn_Prepare_invalid(t *testing.T) {
t.Error("got message:", got)
}
}
func TestConn_MustPrepare_empty(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(``)
t.Error("want panic")
}
func TestConn_MustPrepare_tail(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT 1; -- HERE`)
t.Error("want panic")
}
func TestConn_MustPrepare_invalid(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.MustPrepare(`SELECT`)
t.Error("want panic")
}
func TestConn_Pragma(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
defer func() { _ = recover() }()
_ = db.Pragma("encoding=''")
t.Error("want panic")
}

77
tests/ext_test.go Normal file
View File

@@ -0,0 +1,77 @@
package tests
import (
"testing"
"github.com/ncruces/go-sqlite3"
)
func Test_base64(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
// base64
stmt, _, err := db.Prepare(`SELECT base64('TWFueSBoYW5kcyBtYWtlIGxpZ2h0IHdvcmsu')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if got := stmt.ColumnText(0); got != "Many hands make light work." {
t.Errorf("got %q", got)
}
}
func Test_decimal(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT decimal_add(decimal('0.1'), decimal('0.2')) = decimal('0.3')`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}
func Test_uint(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT 'z2' < 'z11' COLLATE UINT`)
if err != nil {
t.Error(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("expected one row")
}
if !stmt.ColumnBool(0) {
t.Error("want true")
}
}

View File

@@ -1,26 +0,0 @@
#!/usr/bin/env bash
set -eo pipefail
cd -P -- "$(dirname -- "$0")"
if [ ! -f "mptest.c" ]; then
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/mptest.c"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/config01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/config02.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/crash01.test"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/crash02.subtest"
curl -sOL "https://github.com/sqlite/sqlite/raw/master/mptest/multiwrite01.test"
fi
zig cc --target=wasm32-wasi -flto -g0 -Os \
-o mptest.wasm main.c test.c \
-I../../../sqlite3 \
-mmutable-globals \
-mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \
-D_HAVE_SQLITE_CONFIG_H \
-DSQLITE_DEFAULT_SYNCHRONOUS=0 \
-DSQLITE_DEFAULT_LOCKING_MODE=0 \
-DHAVE_USLEEP -DSQLITE_NO_SYNC \
-DSQLITE_THREADSAFE=0 -DSQLITE_OMIT_LOAD_EXTENSION \
-D_WASI_EMULATED_GETPID -lwasi-emulated-getpid

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3960b873a7dab969a66f7859d491cec0dd4e6c0c9f83eab449fb15ec5ebdfd8f
size 1077281

View File

@@ -1,5 +0,0 @@
#define unlink dont_unlink
#include "mptest.c"
int dont_unlink(const char *pathname) { return 0; }

View File

@@ -40,7 +40,7 @@ func TestTimeFormat_Decode(t *testing.T) {
t.Parallel()
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
reftime := time.Date(2000, 1, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
refnodate := time.Date(2000, 01, 1, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
tests := []struct {
fmt sqlite3.TimeFormat
@@ -89,14 +89,14 @@ func TestTimeFormat_Decode(t *testing.T) {
{sqlite3.TimeFormatAuto, "1381134199120000", reference, 0, false},
{sqlite3.TimeFormatAuto, "1381134199120000000", reference, 0, false},
{sqlite3.TimeFormatAuto, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", reftime, 0, false},
{sqlite3.TimeFormatAuto, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormatAuto, "abc", time.Time{}, 0, true},
{sqlite3.TimeFormatAuto, false, time.Time{}, 0, true},
{sqlite3.TimeFormat3, "2013-10-07 04:23:19.12-04:00", reference, 0, false},
{sqlite3.TimeFormat3, "2013-10-07 08:23:19.12", reference, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", reftime, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", reftime, 0, false},
{sqlite3.TimeFormat9, "04:23:19.12-04:00", refnodate, 0, false},
{sqlite3.TimeFormat9, "08:23:19.12", refnodate, 0, false},
{sqlite3.TimeFormat3, false, time.Time{}, 0, true},
{sqlite3.TimeFormat9, false, time.Time{}, 0, true},
@@ -118,3 +118,52 @@ func TestTimeFormat_Decode(t *testing.T) {
})
}
}
func TestDB_timeCollation(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS times (tstamp COLLATE TIME)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO times VALUES (?), (?), (?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
stmt.BindTime(1, time.Unix(0, 0).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(2, time.Unix(0, -1).UTC(), sqlite3.TimeFormatDefault)
stmt.BindTime(3, time.Unix(0, +1).UTC(), sqlite3.TimeFormatDefault)
stmt.Step()
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
stmt, _, err = db.Prepare(`SELECT tstamp FROM times ORDER BY tstamp`)
if err != nil {
t.Fatal(err)
}
var t0 time.Time
for stmt.Step() {
t1 := stmt.ColumnTime(0, sqlite3.TimeFormatAuto)
if t0.After(t1) {
t.Errorf("got %v after %v", t0, t1)
}
t0 = t1
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -185,10 +185,10 @@ func TestConn_Transaction_interrupt(t *testing.T) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
var nilErr error
tx.End(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -210,6 +210,33 @@ func TestConn_Transaction_interrupt(t *testing.T) {
}
}
func TestConn_Transaction_interrupted(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
cancel()
tx := db.Begin()
err = tx.Commit()
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
err = nil
tx.End(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
}
func TestConn_Transaction_rollback(t *testing.T) {
t.Parallel()
@@ -286,7 +313,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
}
insert := func(succeed bool) (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -344,7 +371,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
}
panics := func() (err error) {
defer db.Savepoint()(&err)
defer db.Savepoint().Release(&err)
err = db.Exec(`INSERT INTO test VALUES ('hello')`)
if err != nil {
@@ -395,12 +422,12 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}
@@ -408,19 +435,19 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx)
release1 := db.Savepoint()
savept1 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (2)`)
if err != nil {
t.Fatal(err)
}
release2 := db.Savepoint()
savept2 := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (3)`)
if err != nil {
t.Fatal(err)
}
cancel()
db.Savepoint()(&err)
db.Savepoint().Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
@@ -431,15 +458,15 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
}
err = context.Canceled
release2(&err)
savept2.Release(&err)
if err != context.Canceled {
t.Fatal(err)
}
var nilErr error
release1(&nilErr)
if !errors.Is(nilErr, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", nilErr)
err = nil
savept1.Release(&err)
if !errors.Is(err, sqlite3.INTERRUPT) {
t.Errorf("got %v, want sqlite3.INTERRUPT", err)
}
db.SetInterrupt(context.Background())
@@ -475,7 +502,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
t.Fatal(err)
}
release := db.Savepoint()
savept := db.Savepoint()
err = db.Exec(`INSERT INTO test VALUES (1)`)
if err != nil {
t.Fatal(err)
@@ -484,7 +511,7 @@ func TestConn_Savepoint_rollback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
release(&err)
savept.Release(&err)
if err != nil {
t.Fatal(err)
}

40
time.go
View File

@@ -6,6 +6,7 @@ import (
"strings"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
)
@@ -62,13 +63,18 @@ const (
// [TimeFormatDefault] and [TimeFormatAuto] encode using [time.RFC3339Nano],
// with nanosecond accuracy, and preserving any timezone offset.
//
// This is the format used by the database/sql driver:
// [database/sql.Row.Scan] is able to decode as [time.Time]
// This is the format used by the [database/sql] driver:
// [database/sql.Row.Scan] will decode as [time.Time]
// values encoded with [time.RFC3339Nano].
//
// Time values encoded with [time.RFC3339Nano] cannot be sorted as strings
// to produce a time-ordered sequence.
// Use [TimeFormat7] for time-ordered encoding.
//
// Assuming that the time zones of the time values are the same (e.g., all in UTC),
// and expressed using the same string (e.g., all "Z" or all "+00:00"),
// use the TIME [collating sequence] to produce a time-ordered sequence.
//
// Otherwise, use [TimeFormat7] for time-ordered encoding.
//
// Formats [TimeFormat1] through [TimeFormat10]
// convert time values to UTC before encoding.
@@ -78,6 +84,8 @@ const (
// or an int64 for the other numeric formats.
//
// https://www.sqlite.org/lang_datefunc.html
//
// [collating sequence]: https://www.sqlite.org/datatype3.html#collating_sequences
func (f TimeFormat) Encode(t time.Time) any {
switch f {
// Numeric formats
@@ -123,9 +131,9 @@ func (f TimeFormat) Encode(t time.Time) any {
// [TimeFormatAuto] implements (and extends) the SQLite auto modifier.
// Julian day numbers are safe to use for historical dates,
// from 4712BC through 9999AD.
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds),
// are safe to use for current events, from 1980 through at least 2260.
// Unix timestamps before 1980 may be misinterpreted as julian day numbers,
// Unix timestamps (expressed in seconds, milliseconds, microseconds, or nanoseconds)
// are safe to use for current events, from at least 1980 through at least 2260.
// Unix timestamps before 1980 and after 9999 may be misinterpreted as julian day numbers,
// or have the wrong time unit.
//
// https://www.sqlite.org/lang_datefunc.html
@@ -141,7 +149,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return julianday.Time(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnix, TimeFormatUnixFrac:
@@ -160,7 +168,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(v, 0), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMilli:
@@ -177,7 +185,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMilli(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixMicro:
@@ -194,14 +202,14 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.UnixMicro(int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case TimeFormatUnixNano:
if s, ok := v.(string); ok {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
v = i
}
@@ -211,7 +219,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
case int64:
return time.Unix(0, int64(v)), nil
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
// Special formats
@@ -281,7 +289,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
}
return TimeFormatUnixNano.Decode(v)
default:
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
case
@@ -293,7 +301,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat7, TimeFormat7TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
return f.parseRelaxed(s)
@@ -303,7 +311,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
TimeFormat10, TimeFormat10TZ:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
t, err := f.parseRelaxed(s)
return t.AddDate(2000, 0, 0), err
@@ -311,7 +319,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) {
default:
s, ok := v.(string)
if !ok {
return time.Time{}, timeErr
return time.Time{}, util.TimeErr
}
if f == "" {
f = time.RFC3339Nano

154
tx.go
View File

@@ -4,10 +4,14 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"runtime"
"strconv"
)
// Tx is an in-progress database transaction.
//
// https://www.sqlite.org/lang_transaction.html
type Tx struct {
c *Conn
}
@@ -16,8 +20,9 @@ type Tx struct {
//
// https://www.sqlite.org/lang_transaction.html
func (c *Conn) Begin() Tx {
err := c.Exec(`BEGIN DEFERRED`)
if err != nil && !errors.Is(err, INTERRUPT) {
// BEGIN even if interrupted.
err := c.txExecInterrupted(`BEGIN DEFERRED`)
if err != nil {
panic(err)
}
return Tx{c}
@@ -64,21 +69,22 @@ func (tx Tx) End(errp *error) {
defer panic(recovered)
}
if tx.c.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
if tx.c.GetAutocommit() { // There is nothing to commit.
return
}
*errp = tx.Commit()
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
// Fall through to the error path.
}
// Error path.
if tx.c.GetAutocommit() { // There is nothing to rollback.
return
}
err := tx.Rollback()
if err != nil {
panic(err)
@@ -92,33 +98,28 @@ func (tx Tx) Commit() error {
return tx.c.Exec(`COMMIT`)
}
// Rollback rollsback the transaction.
// Rollback rolls back the transaction,
// even if the connection has been interrupted.
//
// https://www.sqlite.org/lang_transaction.html
func (tx Tx) Rollback() error {
// ROLLBACK even if the connection has been interrupted.
old := tx.c.SetInterrupt(context.Background())
defer tx.c.SetInterrupt(old)
return tx.c.Exec(`ROLLBACK`)
return tx.c.txExecInterrupted(`ROLLBACK`)
}
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a release func that will call either
// RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// // ... do work in the transaction
// }
// Savepoint is a marker within a transaction
// that allows for partial rollback.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
type Savepoint struct {
c *Conn
name string
}
// Savepoint establishes a new transaction savepoint.
//
// https://www.sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint {
name := "sqlite3.Savepoint"
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
@@ -127,52 +128,75 @@ func (c *Conn) Savepoint() (release func(*error)) {
name = frame.Function
}
}
// Names can be reused; this makes catching bugs more likely.
name += "#" + strconv.Itoa(int(rand.Int31()))
err := c.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
if errors.Is(err, INTERRUPT) {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
panic(err)
}
return Savepoint{c: c, name: name}
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
// Release releases the savepoint rolling back any changes
// if *error points to a non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// savept := conn.Savepoint()
// defer savept.Release(&err)
//
// // ... do work in the transaction
// }
func (s Savepoint) Release(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if c.GetAutocommit() {
// There is nothing to commit/rollback.
if *errp == nil && recovered == nil {
// Success path.
if s.c.GetAutocommit() { // There is nothing to commit.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil {
return
}
// Fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err := c.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = c.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
// Error path.
if s.c.GetAutocommit() { // There is nothing to rollback.
return
}
// ROLLBACK and RELEASE even if interrupted.
err := s.c.txExecInterrupted(fmt.Sprintf(`
ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil {
panic(err)
}
}
// Rollback rolls the transaction back to the savepoint,
// even if the connection has been interrupted.
// Rollback does not release the savepoint.
//
// https://www.sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted.
return s.c.txExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
}
func (c *Conn) txExecInterrupted(sql string) error {
err := c.Exec(sql)
if errors.Is(err, INTERRUPT) {
old := c.SetInterrupt(context.Background())
defer c.SetInterrupt(old)
err = c.Exec(sql)
}
return err
}

348
vfs.go
View File

@@ -1,348 +0,0 @@
package sqlite3
import (
"context"
"crypto/rand"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/sys"
)
func vfsInstantiate(ctx context.Context, r wazero.Runtime) {
wasi := r.NewHostModuleBuilder("wasi_snapshot_preview1")
wasi.NewFunctionBuilder().WithFunc(vfsExit).Export("proc_exit")
_, err := wasi.Instantiate(ctx)
if err != nil {
panic(err)
}
env := vfsNewEnvModuleBuilder(r)
_, err = env.Instantiate(ctx)
if err != nil {
panic(err)
}
}
func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder {
env := r.NewHostModuleBuilder("env")
env.NewFunctionBuilder().WithFunc(vfsLocaltime).Export("os_localtime")
env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("os_randomness")
env.NewFunctionBuilder().WithFunc(vfsSleep).Export("os_sleep")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("os_current_time")
env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("os_current_time_64")
env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("os_full_pathname")
env.NewFunctionBuilder().WithFunc(vfsDelete).Export("os_delete")
env.NewFunctionBuilder().WithFunc(vfsAccess).Export("os_access")
env.NewFunctionBuilder().WithFunc(vfsOpen).Export("os_open")
env.NewFunctionBuilder().WithFunc(vfsClose).Export("os_close")
env.NewFunctionBuilder().WithFunc(vfsRead).Export("os_read")
env.NewFunctionBuilder().WithFunc(vfsWrite).Export("os_write")
env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("os_truncate")
env.NewFunctionBuilder().WithFunc(vfsSync).Export("os_sync")
env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("os_file_size")
env.NewFunctionBuilder().WithFunc(vfsLock).Export("os_lock")
env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("os_unlock")
env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("os_check_reserved_lock")
env.NewFunctionBuilder().WithFunc(vfsFileControl).Export("os_file_control")
return env
}
// Poor man's namespaces.
const (
vfsOS vfsOSMethods = false
vfsFile vfsFileMethods = false
)
type (
vfsOSMethods bool
vfsFileMethods bool
)
type vfsKey struct{}
type vfsState struct {
files []*os.File
}
func vfsContext(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
func (vfs *vfsState) Close() error {
for _, f := range vfs.files {
if f != nil {
f.Close()
}
}
vfs.files = nil
return nil
}
func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) {
// Ensure other callers see the exit code.
_ = mod.CloseWithExitCode(ctx, exitCode)
// Prevent any code from executing after this function.
panic(sys.NewExitError(mod.Name(), exitCode))
}
func vfsLocaltime(ctx context.Context, mod api.Module, t uint64, pTm uint32) uint32 {
tm := time.Unix(int64(t), 0)
var isdst int
if tm.IsDST() {
isdst = 1
}
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
mem := memory{mod}
mem.writeUint32(pTm+0*ptrlen, uint32(tm.Second()))
mem.writeUint32(pTm+1*ptrlen, uint32(tm.Minute()))
mem.writeUint32(pTm+2*ptrlen, uint32(tm.Hour()))
mem.writeUint32(pTm+3*ptrlen, uint32(tm.Day()))
mem.writeUint32(pTm+4*ptrlen, uint32(tm.Month()-time.January))
mem.writeUint32(pTm+5*ptrlen, uint32(tm.Year()-1900))
mem.writeUint32(pTm+6*ptrlen, uint32(tm.Weekday()-time.Sunday))
mem.writeUint32(pTm+7*ptrlen, uint32(tm.YearDay()-1))
mem.writeUint32(pTm+8*ptrlen, uint32(isdst))
return _OK
}
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
mem := memory{mod}.view(zByte, uint64(nByte))
n, _ := rand.Reader.Read(mem)
return uint32(n)
}
func vfsSleep(ctx context.Context, pVfs, nMicro uint32) uint32 {
time.Sleep(time.Duration(nMicro) * time.Microsecond)
return _OK
}
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uint32 {
day := julianday.Float(time.Now())
memory{mod}.writeFloat64(prNow, day)
return _OK
}
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) uint32 {
day, nsec := julianday.Date(time.Now())
msec := day*86_400_000 + nsec/1_000_000
memory{mod}.writeUint64(piNow, uint64(msec))
return _OK
}
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) uint32 {
rel := memory{mod}.readString(zRelative, _MAX_PATHNAME)
abs, err := filepath.Abs(rel)
if err != nil {
return uint32(IOERR)
}
// Consider either using [filepath.EvalSymlinks] to canonicalize the path (as the Unix VFS does).
// Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did).
// This might be buggy on Windows (the Windows VFS doesn't try).
size := uint64(len(abs) + 1)
if size > uint64(nFull) {
return uint32(CANTOPEN_FULLPATH)
}
mem := memory{mod}.view(zFull, size)
mem[len(abs)] = 0
copy(mem, abs)
return _OK
}
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) uint32 {
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _OK
}
if err != nil {
return uint32(IOERR_DELETE)
}
if runtime.GOOS != "windows" && syncDir != 0 {
f, err := os.Open(filepath.Dir(path))
if err == nil {
err = f.Sync()
f.Close()
}
if err != nil {
return uint32(IOERR_DELETE)
}
}
return _OK
}
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 {
// Consider using [syscall.Access] for [ACCESS_READWRITE]/[ACCESS_READ]
// (as the Unix VFS does).
path := memory{mod}.readString(zPath, _MAX_PATHNAME)
fi, err := os.Stat(path)
var res uint32
switch {
case flags == _ACCESS_EXISTS:
switch {
case err == nil:
res = 1
case errors.Is(err, fs.ErrNotExist):
res = 0
default:
return uint32(IOERR_ACCESS)
}
case err == nil:
var want fs.FileMode = syscall.S_IRUSR
if flags == _ACCESS_READWRITE {
want |= syscall.S_IWUSR
}
if fi.IsDir() {
want |= syscall.S_IXUSR
}
if fi.Mode()&want == want {
res = 1
} else {
res = 0
}
case errors.Is(err, fs.ErrPermission):
res = 0
default:
return uint32(IOERR_ACCESS)
}
memory{mod}.writeUint32(pResOut, res)
return _OK
}
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
}
if flags&OPEN_CREATE != 0 {
oflags |= os.O_CREATE
}
if flags&OPEN_READONLY != 0 {
oflags |= os.O_RDONLY
}
if flags&OPEN_READWRITE != 0 {
oflags |= os.O_RDWR
}
var err error
var file *os.File
if zName == 0 {
file, err = os.CreateTemp("", "*.db")
} else {
name := memory{mod}.readString(zName, _MAX_PATHNAME)
file, err = os.OpenFile(name, oflags, 0600)
}
if err != nil {
return uint32(CANTOPEN)
}
if flags&OPEN_DELETEONCLOSE != 0 {
vfsOS.DeleteOnClose(file)
}
vfsFile.Open(ctx, mod, pFile, file)
if pOutFlags != 0 {
memory{mod}.writeUint32(pOutFlags, uint32(flags))
}
return _OK
}
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 {
err := vfsFile.Close(ctx, mod, pFile)
if err != nil {
return uint32(IOERR_CLOSE)
}
return _OK
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFile.GetOS(ctx, mod, pFile)
n, err := file.ReadAt(buf, int64(iOfst))
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return uint32(IOERR_READ)
}
for i := range buf[n:] {
buf[n+i] = 0
}
return uint32(IOERR_SHORT_READ)
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 {
buf := memory{mod}.view(zBuf, uint64(iAmt))
file := vfsFile.GetOS(ctx, mod, pFile)
_, err := file.WriteAt(buf, int64(iOfst))
if err != nil {
return uint32(IOERR_WRITE)
}
return _OK
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 {
file := vfsFile.GetOS(ctx, mod, pFile)
err := file.Truncate(int64(nByte))
if err != nil {
return uint32(IOERR_TRUNCATE)
}
return _OK
}
func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 {
file := vfsFile.GetOS(ctx, mod, pFile)
err := file.Sync()
if err != nil {
return uint32(IOERR_FSYNC)
}
return _OK
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 {
// This uses [os.File.Seek] because we don't care about the offset for reading/writing.
// But consider using [os.File.Stat] instead (as other VFSes do).
file := vfsFile.GetOS(ctx, mod, pFile)
off, err := file.Seek(0, io.SeekEnd)
if err != nil {
return uint32(IOERR_SEEK)
}
memory{mod}.writeUint64(pSize, uint64(off))
return _OK
}
func vfsFileControl(ctx context.Context, pFile, op, pArg uint32) uint32 {
// SQLite calls vfsFileControl with these opcodes:
// SQLITE_FCNTL_SIZE_HINT
// SQLITE_FCNTL_PRAGMA
// SQLITE_FCNTL_BUSYHANDLER
// SQLITE_FCNTL_HAS_MOVED
// SQLITE_FCNTL_SYNC
// SQLITE_FCNTL_COMMIT_PHASETWO
// SQLITE_FCNTL_PDB
return uint32(NOTFOUND)
}

View File

@@ -1,57 +0,0 @@
package sqlite3
import (
"context"
"os"
"github.com/tetratelabs/wazero/api"
)
func (vfsFileMethods) NewID(ctx context.Context, file *os.File) uint32 {
vfs := ctx.Value(vfsKey{}).(*vfsState)
// Find an empty slot.
for id, ptr := range vfs.files {
if ptr == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func (vfsFileMethods) Open(ctx context.Context, mod api.Module, pFile uint32, file *os.File) {
mem := memory{mod}
id := vfsFile.NewID(ctx, file)
mem.writeUint32(pFile+ptrlen, id)
mem.writeUint32(pFile+2*ptrlen, _NO_LOCK)
}
func (vfsFileMethods) Close(ctx context.Context, mod api.Module, pFile uint32) error {
mem := memory{mod}
id := mem.readUint32(pFile + ptrlen)
vfs := ctx.Value(vfsKey{}).(*vfsState)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
}
func (vfsFileMethods) GetOS(ctx context.Context, mod api.Module, pFile uint32) *os.File {
mem := memory{mod}
id := mem.readUint32(pFile + ptrlen)
vfs := ctx.Value(vfsKey{}).(*vfsState)
return vfs.files[id]
}
func (vfsFileMethods) GetLock(ctx context.Context, mod api.Module, pFile uint32) vfsLockState {
mem := memory{mod}
return vfsLockState(mem.readUint32(pFile + 2*ptrlen))
}
func (vfsFileMethods) SetLock(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) {
mem := memory{mod}
mem.writeUint32(pFile+2*ptrlen, uint32(lock))
}

View File

@@ -1,212 +0,0 @@
package sqlite3
import (
"context"
"os"
"github.com/tetratelabs/wazero/api"
)
const (
// No locks are held on the database.
// The database may be neither read nor written.
// Any internally cached data is considered suspect and subject to
// verification against the database file before being used.
// Other processes can read or write the database as their own locking
// states permit.
// This is the default state.
_NO_LOCK = 0
// The database may be read but not written.
// Any number of processes can hold SHARED locks at the same time,
// hence there can be many simultaneous readers.
// But no other thread or process is allowed to write to the database file
// while one or more SHARED locks are active.
_SHARED_LOCK = 1
// A RESERVED lock means that the process is planning on writing to the
// database file at some point in the future but that it is currently just
// reading from the file.
// Only a single RESERVED lock may be active at one time,
// though multiple SHARED locks can coexist with a single RESERVED lock.
// RESERVED differs from PENDING in that new SHARED locks can be acquired
// while there is a RESERVED lock.
_RESERVED_LOCK = 2
// A PENDING lock means that the process holding the lock wants to write to
// the database as soon as possible and is just waiting on all current
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
// No new SHARED locks are permitted against the database if a PENDING lock
// is active, though existing SHARED locks are allowed to continue.
_PENDING_LOCK = 3
// An EXCLUSIVE lock is needed in order to write to the database file.
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
// kind are allowed to coexist with an EXCLUSIVE lock.
// In order to maximize concurrency, SQLite works to minimize the amount of
// time that EXCLUSIVE locks are held.
_EXCLUSIVE_LOCK = 4
_PENDING_BYTE = 0x40000000
_RESERVED_BYTE = (_PENDING_BYTE + 1)
_SHARED_FIRST = (_PENDING_BYTE + 2)
_SHARED_SIZE = 510
)
type vfsLockState uint32
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check. SQLite never explicitly requests a pendig lock.
if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK {
panic(assertErr())
}
file := vfsFile.GetOS(ctx, mod, pFile)
cLock := vfsFile.GetLock(ctx, mod, pFile)
switch {
case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK:
// Connection state check.
panic(assertErr())
case cLock == _NO_LOCK && eLock > _SHARED_LOCK:
// We never move from unlocked to anything higher than a shared lock.
panic(assertErr())
case cLock != _SHARED_LOCK && eLock == _RESERVED_LOCK:
// A shared lock is always held when a reserved lock is requested.
panic(assertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if cLock >= eLock {
return _OK
}
switch eLock {
case _SHARED_LOCK:
// Must be unlocked to get SHARED.
if cLock != _NO_LOCK {
panic(assertErr())
}
// Test the PENDING lock before acquiring a new SHARED lock.
if locked, _ := vfsOS.CheckPendingLock(file); locked {
return uint32(BUSY)
}
if rc := vfsOS.GetSharedLock(file); rc != _OK {
return uint32(rc)
}
vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK)
return _OK
case _RESERVED_LOCK:
// Must be SHARED to get RESERVED.
if cLock != _SHARED_LOCK {
panic(assertErr())
}
if rc := vfsOS.GetReservedLock(file); rc != _OK {
return uint32(rc)
}
vfsFile.SetLock(ctx, mod, pFile, _RESERVED_LOCK)
return _OK
case _EXCLUSIVE_LOCK:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if cLock <= _NO_LOCK || cLock >= _EXCLUSIVE_LOCK {
panic(assertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if cLock == _RESERVED_LOCK {
if rc := vfsOS.GetPendingLock(file); rc != _OK {
return uint32(rc)
}
vfsFile.SetLock(ctx, mod, pFile, _PENDING_LOCK)
}
if rc := vfsOS.GetExclusiveLock(file); rc != _OK {
return uint32(rc)
}
vfsFile.SetLock(ctx, mod, pFile, _EXCLUSIVE_LOCK)
return _OK
default:
panic(assertErr())
}
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 {
// Argument check.
if eLock != _NO_LOCK && eLock != _SHARED_LOCK {
panic(assertErr())
}
file := vfsFile.GetOS(ctx, mod, pFile)
cLock := vfsFile.GetLock(ctx, mod, pFile)
// Connection state check.
if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK {
panic(assertErr())
}
// If we don't have a more restrictive lock, do nothing.
if cLock <= eLock {
return _OK
}
switch eLock {
case _SHARED_LOCK:
if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK {
return uint32(rc)
}
vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK)
return _OK
case _NO_LOCK:
rc := vfsOS.ReleaseLock(file, cLock)
vfsFile.SetLock(ctx, mod, pFile, _NO_LOCK)
return uint32(rc)
default:
panic(assertErr())
}
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 {
cLock := vfsFile.GetLock(ctx, mod, pFile)
if cLock > _SHARED_LOCK {
panic(assertErr())
}
file := vfsFile.GetOS(ctx, mod, pFile)
locked, rc := vfsOS.CheckReservedLock(file)
var res uint32
if locked {
res = 1
}
memory{mod}.writeUint32(pResOut, res)
return uint32(rc)
}
func (vfsOSMethods) GetSharedLock(file *os.File) xErrorCode {
// Acquire the SHARED lock.
return vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
func (vfsOSMethods) GetReservedLock(file *os.File) xErrorCode {
// Acquire the RESERVED lock.
return vfsOS.writeLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) GetPendingLock(file *os.File) xErrorCode {
// Acquire the PENDING lock.
return vfsOS.writeLock(file, _PENDING_BYTE, 1)
}
func (vfsOSMethods) CheckReservedLock(file *os.File) (bool, xErrorCode) {
// Test the RESERVED lock.
return vfsOS.checkLock(file, _RESERVED_BYTE, 1)
}
func (vfsOSMethods) CheckPendingLock(file *os.File) (bool, xErrorCode) {
// Test the PENDING lock.
return vfsOS.checkLock(file, _PENDING_BYTE, 1)
}

View File

@@ -1,128 +0,0 @@
package sqlite3
import (
"context"
"os"
"path/filepath"
"runtime"
"testing"
)
func Test_vfsLock(t *testing.T) {
switch runtime.GOOS {
case "linux", "darwin", "illumos", "windows":
break
default:
t.Skip("OS lacks OFD locks")
}
name := filepath.Join(t.TempDir(), "test.db")
// Create a temporary file.
file1, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
t.Fatal(err)
}
defer file1.Close()
// Open the temporary file again.
file2, err := os.OpenFile(name, os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file2.Close()
const (
pFile1 = 4
pFile2 = 16
pOutput = 32
)
mem := newMemory(128)
ctx, vfs := vfsContext(context.TODO())
defer vfs.Close()
vfsFile.Open(ctx, mem.mod, pFile1, file1)
vfsFile.Open(ctx, mem.mod, pFile2, file2)
rc := vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mem.mod, pFile2, _RESERVED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(ctx, mem.mod, pFile2, _EXCLUSIVE_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK)
if rc == _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got == 0 {
t.Error("file wasn't locked")
}
rc = vfsUnlock(ctx, mem.mod, pFile2, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := mem.readUint32(pOutput); got != 0 {
t.Error("file was locked")
}
rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK)
if rc != _OK {
t.Fatal("returned", rc)
}
}

View File

@@ -1,137 +0,0 @@
//go:build unix
package sqlite3
import (
"os"
"runtime"
"syscall"
)
func (vfsOSMethods) DeleteOnClose(file *os.File) {
_ = os.Remove(file.Name())
}
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Acquire the EXCLUSIVE lock.
return vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Downgrade to a SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// In theory, the downgrade to a SHARED cannot fail because another
// process is holding an incompatible lock. If it does, this
// indicates that the other process is not following the locking
// protocol. If this happens, return IOERR_RDLOCK. Returning
// BUSY would confuse the upper layer.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
return vfsOS.unlock(file, _PENDING_BYTE, 2)
}
func (vfsOSMethods) ReleaseLock(file *os.File, _ vfsLockState) xErrorCode {
// Release all locks.
return vfsOS.unlock(file, 0, 0)
}
func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode {
err := vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_UNLCK,
Start: start,
Len: len,
})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (vfsOSMethods) readLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}), IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len int64) xErrorCode {
return vfsOS.lockErrorCode(vfsOS.fcntlSetLock(file, &syscall.Flock_t{
Type: syscall.F_WRLCK,
Start: start,
Len: len,
}), IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) {
lock := syscall.Flock_t{
Type: syscall.F_RDLCK,
Start: start,
Len: len,
}
if vfsOS.fcntlGetLock(file, &lock) != nil {
return false, IOERR_CHECKRESERVEDLOCK
}
return lock.Type != syscall.F_UNLCK, _OK
}
func (vfsOSMethods) fcntlGetLock(file *os.File, lock *syscall.Flock_t) error {
var F_OFD_GETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_OFD_GETLK = 36 // F_OFD_GETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_OFD_GETLK = 92 // F_OFD_GETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_OFD_GETLK = 47 // F_OFD_GETLK
default:
return notImplErr
}
return syscall.FcntlFlock(file.Fd(), F_OFD_GETLK, lock)
}
func (vfsOSMethods) fcntlSetLock(file *os.File, lock *syscall.Flock_t) error {
var F_OFD_SETLK int
switch runtime.GOOS {
case "linux":
// https://github.com/torvalds/linux/blob/master/include/uapi/asm-generic/fcntl.h
F_OFD_SETLK = 37 // F_OFD_SETLK
case "darwin":
// https://github.com/apple/darwin-xnu/blob/main/bsd/sys/fcntl.h
F_OFD_SETLK = 90 // F_OFD_SETLK
case "illumos":
// https://github.com/illumos/illumos-gate/blob/master/usr/src/uts/common/sys/fcntl.h
F_OFD_SETLK = 48 // F_OFD_SETLK
default:
return notImplErr
}
return syscall.FcntlFlock(file.Fd(), F_OFD_SETLK, lock)
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case
syscall.EACCES,
syscall.EAGAIN,
syscall.EBUSY,
syscall.EINTR,
syscall.ENOLCK,
syscall.EDEADLK,
syscall.ETIMEDOUT:
return xErrorCode(BUSY)
case syscall.EPERM:
return xErrorCode(PERM)
}
}
return def
}

View File

@@ -1,102 +0,0 @@
package sqlite3
import (
"os"
"syscall"
"golang.org/x/sys/windows"
)
func (vfsOSMethods) DeleteOnClose(file *os.File) {}
func (vfsOSMethods) GetExclusiveLock(file *os.File) xErrorCode {
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Acquire the EXCLUSIVE lock.
rc := vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc != _OK {
vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE)
}
return rc
}
func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode {
if state >= _EXCLUSIVE_LOCK {
// Release the SHARED lock.
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
// Reacquire the SHARED lock.
if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE); rc != _OK {
// This should never happen.
// We should always be able to reacquire the read lock.
return IOERR_RDLOCK
}
}
// Release the PENDING and RESERVED locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) ReleaseLock(file *os.File, state vfsLockState) xErrorCode {
// Release all locks.
if state >= _RESERVED_LOCK {
vfsOS.unlock(file, _RESERVED_BYTE, 1)
}
if state >= _SHARED_LOCK {
vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE)
}
if state >= _PENDING_LOCK {
vfsOS.unlock(file, _PENDING_BYTE, 1)
}
return _OK
}
func (vfsOSMethods) unlock(file *os.File, start, len uint32) xErrorCode {
err := windows.UnlockFileEx(windows.Handle(file.Fd()),
0, len, 0, &windows.Overlapped{Offset: start})
if err != nil {
return IOERR_UNLOCK
}
return _OK
}
func (vfsOSMethods) readLock(file *os.File, start, len uint32) xErrorCode {
return vfsOS.lockErrorCode(windows.LockFileEx(windows.Handle(file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_RDLOCK)
}
func (vfsOSMethods) writeLock(file *os.File, start, len uint32) xErrorCode {
return vfsOS.lockErrorCode(windows.LockFileEx(windows.Handle(file.Fd()),
windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK,
0, len, 0, &windows.Overlapped{Offset: start}),
IOERR_LOCK)
}
func (vfsOSMethods) checkLock(file *os.File, start, len uint32) (bool, xErrorCode) {
rc := vfsOS.readLock(file, start, len)
if rc == _OK {
vfsOS.unlock(file, start, len)
}
return rc != _OK, _OK
}
func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode {
if err == nil {
return _OK
}
if errno, _ := err.(syscall.Errno); errno == windows.ERROR_INVALID_HANDLE {
return def
}
return xErrorCode(BUSY)
}