mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 22:19:14 +00:00
Compare commits
88 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1427d30541 | ||
|
|
d3730341f0 | ||
|
|
78ac2386f6 | ||
|
|
632ea933b3 | ||
|
|
0f7fa6ebc9 | ||
|
|
6f7f776488 | ||
|
|
f6d7c5e9c5 | ||
|
|
1cc7ecfe8d | ||
|
|
3844e81404 | ||
|
|
fec1f8d32a | ||
|
|
31572e6095 | ||
|
|
4aee38b957 | ||
|
|
232a7705b5 | ||
|
|
a6c2fccd74 | ||
|
|
6a982559cd | ||
|
|
c7904d30de | ||
|
|
ce4386604d | ||
|
|
26b62c520d | ||
|
|
738714bf32 | ||
|
|
41b020bafc | ||
|
|
d0e720272b | ||
|
|
76171da12b | ||
|
|
dcc845d684 | ||
|
|
f1b42c26d5 | ||
|
|
1e94407ae7 | ||
|
|
eb8d9b95fd | ||
|
|
04037a75ed | ||
|
|
2472ceb0a0 | ||
|
|
bfe9bfde2e | ||
|
|
f07e82e361 | ||
|
|
fbbbe5a631 | ||
|
|
5ea603ed78 | ||
|
|
401cb77e38 | ||
|
|
6511175011 | ||
|
|
f7d987fdf1 | ||
|
|
00ba681bb5 | ||
|
|
d4d4533a41 | ||
|
|
ec9533b13f | ||
|
|
8fe77a065c | ||
|
|
7bf5312bd4 | ||
|
|
ae7b74d858 | ||
|
|
9a8de3ad13 | ||
|
|
05737e6025 | ||
|
|
ac2836bb82 | ||
|
|
d0d4b0e1a2 | ||
|
|
dc3dc6853d | ||
|
|
830240c368 | ||
|
|
dedec8682b | ||
|
|
a33b828e13 | ||
|
|
8b2e96dedc | ||
|
|
f1c46db512 | ||
|
|
7ca9d79424 | ||
|
|
254d473546 | ||
|
|
5639fc1ff8 | ||
|
|
ae4954d09b | ||
|
|
45937d9749 | ||
|
|
eee71e06aa | ||
|
|
9e7b6bb8ea | ||
|
|
597178f80d | ||
|
|
cc2d16ac83 | ||
|
|
cfb69e4ce7 | ||
|
|
e6969432e3 | ||
|
|
2b3da350cc | ||
|
|
336ba87d56 | ||
|
|
dd4823ebf0 | ||
|
|
663b23ff3b | ||
|
|
4e2ce6c635 | ||
|
|
66effb4249 | ||
|
|
e1cce83f71 | ||
|
|
df953b31c2 | ||
|
|
67cc3d35d5 | ||
|
|
6846b72b31 | ||
|
|
c94cdaf720 | ||
|
|
f6a887dd1c | ||
|
|
2a010a2022 | ||
|
|
c86b06b048 | ||
|
|
a44a13a506 | ||
|
|
4604719966 | ||
|
|
03168d5d34 | ||
|
|
be4b6304f9 | ||
|
|
b5e678a40a | ||
|
|
2fc4698ddc | ||
|
|
bd86539577 | ||
|
|
7a785d9aec | ||
|
|
59f79e8e74 | ||
|
|
40457721d7 | ||
|
|
18eeb85783 | ||
|
|
b36536979b |
2
.github/FUNDING.yml
vendored
2
.github/FUNDING.yml
vendored
@@ -1 +1 @@
|
||||
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6
|
||||
custom: https://www.paypal.com/donate?hosted_button_id=33P59ELZWGMK6
|
||||
6
.github/workflows/go.yml
vendored
6
.github/workflows/go.yml
vendored
@@ -5,6 +5,7 @@ on:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -33,8 +34,9 @@ jobs:
|
||||
- name: Download
|
||||
run: go mod download
|
||||
|
||||
- name: Verify
|
||||
run: go mod verify
|
||||
# Fixed in go 1.21: https://go.dev/issue/54372
|
||||
# - name: Verify
|
||||
# run: go mod verify
|
||||
|
||||
- name: Vet
|
||||
run: go vet ./...
|
||||
|
||||
31
README.md
31
README.md
@@ -15,28 +15,27 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
|
||||
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
|
||||
embeds a build of SQLite into your application.
|
||||
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
|
||||
### Caveats
|
||||
|
||||
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
|
||||
with a [pure Go](internal/vfs/) implementation.
|
||||
This has numerous benefits, but also comes with some drawbacks.
|
||||
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html)
|
||||
(aka VFS) with a [pure Go](vfs/) implementation.
|
||||
This has benefits, but also comes with some drawbacks.
|
||||
|
||||
#### Write-Ahead Logging
|
||||
|
||||
Because WASM does not support shared memory,
|
||||
[WAL](https://www.sqlite.org/wal.html) support is [limited](https://www.sqlite.org/wal.html#noshm).
|
||||
|
||||
To work around this limitation, SQLite is compiled with
|
||||
[`SQLITE_DEFAULT_LOCKING_MODE=1`](https://www.sqlite.org/compile.html#default_locking_mode),
|
||||
making `EXCLUSIVE` the default locking mode.
|
||||
For non-WAL databases, `NORMAL` locking mode can be activated with
|
||||
[`PRAGMA locking_mode=NORMAL`](https://www.sqlite.org/pragma.html#pragma_locking_mode).
|
||||
To work around this limitation, SQLite is [patched](sqlite3/locking_mode.patch)
|
||||
to always use `EXCLUSIVE` locking mode for WAL databases.
|
||||
|
||||
Because connection pooling is incompatible with `EXCLUSIVE` locking mode,
|
||||
the `database/sql` driver defaults to `NORMAL` locking mode.
|
||||
To open WAL databases, or use `EXCLUSIVE` locking mode,
|
||||
disable connection pooling by calling
|
||||
to open WAL databases you should disable connection pooling by calling
|
||||
[`db.SetMaxOpenConns(1)`](https://pkg.go.dev/database/sql#DB.SetMaxOpenConns).
|
||||
|
||||
#### POSIX Advisory Locks
|
||||
@@ -55,7 +54,7 @@ BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
|
||||
|
||||
#### Testing
|
||||
|
||||
The pure Go VFS is stress tested by running an unmodified build of SQLite's
|
||||
The pure Go VFS is tested by running an unmodified build of SQLite's
|
||||
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
|
||||
on Linux, macOS and Windows.
|
||||
Performance is tested by running
|
||||
@@ -64,16 +63,16 @@ Performance is tested by running
|
||||
### Roadmap
|
||||
|
||||
- [ ] advanced SQLite features
|
||||
- [x] custom functions
|
||||
- [x] nested transactions
|
||||
- [x] incremental BLOB I/O
|
||||
- [x] online backup
|
||||
- [ ] session extension
|
||||
- [ ] custom SQL functions
|
||||
- [ ] custom VFSes
|
||||
- [ ] in-memory VFS
|
||||
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
- [x] custom VFS API
|
||||
- [x] in-memory VFS
|
||||
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
|
||||
- [ ] custom VFS API
|
||||
|
||||
### Alternatives
|
||||
|
||||
|
||||
26
backup.go
26
backup.go
@@ -1,6 +1,6 @@
|
||||
package sqlite3
|
||||
|
||||
// Backup is a handle to an open BLOB.
|
||||
// Backup is an handle to an ongoing online backup operation.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup.html
|
||||
type Backup struct {
|
||||
@@ -11,7 +11,7 @@ type Backup struct {
|
||||
|
||||
// Backup backs up srcDB on the src connection to the "main" database in dstURI.
|
||||
//
|
||||
// Backup calls [Open] to open the SQLite database file dstURI,
|
||||
// Backup opens the SQLite database file dstURI,
|
||||
// and blocks until the entire backup is complete.
|
||||
// Use [Conn.BackupInit] for incremental backup.
|
||||
//
|
||||
@@ -28,7 +28,7 @@ func (src *Conn) Backup(srcDB, dstURI string) error {
|
||||
|
||||
// Restore restores dstDB on the dst connection from the "main" database in srcURI.
|
||||
//
|
||||
// Restore calls [Open] to open the SQLite database file srcURI,
|
||||
// Restore opens the SQLite database file srcURI,
|
||||
// and blocks until the entire restore is complete.
|
||||
//
|
||||
// https://www.sqlite.org/backup.html
|
||||
@@ -48,7 +48,7 @@ func (dst *Conn) Restore(dstDB, srcURI string) error {
|
||||
|
||||
// BackupInit initializes a backup operation to copy the content of one database into another.
|
||||
//
|
||||
// BackupInit calls [Open] to open the SQLite database file dstURI,
|
||||
// BackupInit opens the SQLite database file dstURI,
|
||||
// then initializes a backup that copies the contents of srcDB on the src connection
|
||||
// to the "main" database in dstURI.
|
||||
//
|
||||
@@ -74,16 +74,16 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
|
||||
r := c.call(c.api.backupInit,
|
||||
uint64(dst), uint64(dstPtr),
|
||||
uint64(src), uint64(srcPtr))
|
||||
if r[0] == 0 {
|
||||
if r == 0 {
|
||||
defer c.closeDB(other)
|
||||
r = c.call(c.api.errcode, uint64(dst))
|
||||
return nil, c.module.error(r[0], dst)
|
||||
return nil, c.sqlite.error(r, dst)
|
||||
}
|
||||
|
||||
return &Backup{
|
||||
c: c,
|
||||
otherc: other,
|
||||
handle: uint32(r[0]),
|
||||
handle: uint32(r),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func (b *Backup) Close() error {
|
||||
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
|
||||
b.c.closeDB(b.otherc)
|
||||
b.handle = 0
|
||||
return b.c.error(r[0])
|
||||
return b.c.error(r)
|
||||
}
|
||||
|
||||
// Step copies up to nPage pages between the source and destination databases.
|
||||
@@ -109,10 +109,10 @@ func (b *Backup) Close() error {
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
|
||||
func (b *Backup) Step(nPage int) (done bool, err error) {
|
||||
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
|
||||
if r[0] == _DONE {
|
||||
if r == _DONE {
|
||||
return true, nil
|
||||
}
|
||||
return false, b.c.error(r[0])
|
||||
return false, b.c.error(r)
|
||||
}
|
||||
|
||||
// Remaining returns the number of pages still to be backed up
|
||||
@@ -121,7 +121,7 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
|
||||
func (b *Backup) Remaining() int {
|
||||
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// PageCount returns the total number of pages in the source database
|
||||
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
|
||||
func (b *Backup) PageCount() int {
|
||||
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
|
||||
return int(r[0])
|
||||
r := b.c.call(b.c.api.backupPageCount, uint64(b.handle))
|
||||
return int(r)
|
||||
}
|
||||
|
||||
22
blob.go
22
blob.go
@@ -11,7 +11,7 @@ import (
|
||||
// [database/sql.DB.Exec] and similar methods.
|
||||
type ZeroBlob int64
|
||||
|
||||
// Blob is a handle to an open BLOB.
|
||||
// Blob is an handle to an open BLOB.
|
||||
//
|
||||
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
|
||||
//
|
||||
@@ -45,13 +45,13 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
|
||||
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
|
||||
uint64(row), flags, uint64(blobPtr))
|
||||
|
||||
if err := c.error(r[0]); err != nil {
|
||||
if err := c.error(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blob := Blob{c: c}
|
||||
blob.handle = util.ReadUint32(c.mod, blobPtr)
|
||||
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
|
||||
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle)))
|
||||
return &blob, nil
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func (b *Blob) Close() error {
|
||||
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
|
||||
|
||||
b.handle = 0
|
||||
return b.c.error(r[0])
|
||||
return b.c.error(r)
|
||||
}
|
||||
|
||||
// Size returns the size of the BLOB in bytes.
|
||||
@@ -97,7 +97,7 @@ func (b *Blob) Read(p []byte) (n int, err error) {
|
||||
|
||||
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
|
||||
uint64(ptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r[0])
|
||||
err = b.c.error(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -130,7 +130,7 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
|
||||
for want > 0 {
|
||||
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
|
||||
uint64(ptr), uint64(want), uint64(b.offset))
|
||||
err = b.c.error(r[0])
|
||||
err = b.c.error(r)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (b *Blob) Write(p []byte) (n int, err error) {
|
||||
|
||||
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
|
||||
uint64(ptr), uint64(len(p)), uint64(b.offset))
|
||||
err = b.c.error(r[0])
|
||||
err = b.c.error(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -193,7 +193,7 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
|
||||
if m > 0 {
|
||||
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
|
||||
uint64(ptr), uint64(m), uint64(b.offset))
|
||||
err := b.c.error(r[0])
|
||||
err := b.c.error(r)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
@@ -240,8 +240,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/blob_reopen.html
|
||||
func (b *Blob) Reopen(row int64) error {
|
||||
r := b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))
|
||||
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
|
||||
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row)))
|
||||
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle)))
|
||||
b.offset = 0
|
||||
return b.c.error(r[0])
|
||||
return err
|
||||
}
|
||||
|
||||
46
conn.go
46
conn.go
@@ -19,7 +19,7 @@ import (
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/sqlite3.html
|
||||
type Conn struct {
|
||||
*module
|
||||
*sqlite
|
||||
|
||||
interrupt context.Context
|
||||
waiter chan struct{}
|
||||
@@ -39,7 +39,7 @@ func Open(filename string) (*Conn, error) {
|
||||
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
|
||||
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
|
||||
//
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
|
||||
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/open.html
|
||||
func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
@@ -50,19 +50,19 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
|
||||
}
|
||||
|
||||
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
|
||||
mod, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if conn == nil {
|
||||
mod.close()
|
||||
sqlite.close()
|
||||
} else {
|
||||
runtime.SetFinalizer(conn, util.Finalizer[Conn](3))
|
||||
}
|
||||
}()
|
||||
|
||||
c := &Conn{module: mod}
|
||||
c := &Conn{sqlite: sqlite}
|
||||
c.arena = c.newArena(1024)
|
||||
c.handle, err = c.openDB(filename, flags)
|
||||
if err != nil {
|
||||
@@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
|
||||
|
||||
handle := util.ReadUint32(c.mod, connPtr)
|
||||
if err := c.module.error(r[0], handle); err != nil {
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
c.closeDB(handle)
|
||||
return 0, err
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
c.arena.reset()
|
||||
pragmaPtr := c.arena.string(pragmas.String())
|
||||
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
|
||||
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
|
||||
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
|
||||
if errors.Is(err, ERROR) {
|
||||
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
|
||||
|
||||
func (c *Conn) closeDB(handle uint32) {
|
||||
r := c.call(c.api.closeZombie, uint64(handle))
|
||||
if err := c.module.error(r[0], handle); err != nil {
|
||||
if err := c.sqlite.error(r, handle); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -137,13 +137,13 @@ func (c *Conn) Close() error {
|
||||
c.pending = nil
|
||||
|
||||
r := c.call(c.api.close, uint64(c.handle))
|
||||
if err := c.error(r[0]); err != nil {
|
||||
if err := c.error(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.handle = 0
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return c.module.close()
|
||||
return c.close()
|
||||
}
|
||||
|
||||
// Exec is a convenience function that allows an application to run
|
||||
@@ -156,7 +156,7 @@ func (c *Conn) Exec(sql string) error {
|
||||
sqlPtr := c.arena.string(sql)
|
||||
|
||||
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
|
||||
return c.error(r[0])
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// Prepare calls [Conn.PrepareFlags] with no flags.
|
||||
@@ -189,7 +189,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
i := util.ReadUint32(c.mod, tailPtr)
|
||||
tail = sql[i-sqlPtr:]
|
||||
|
||||
if err := c.error(r[0], sql); err != nil {
|
||||
if err := c.error(r, sql); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if stmt.handle == 0 {
|
||||
@@ -203,7 +203,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
|
||||
// https://www.sqlite.org/c3ref/get_autocommit.html
|
||||
func (c *Conn) GetAutocommit() bool {
|
||||
r := c.call(c.api.autocommit, uint64(c.handle))
|
||||
return r[0] != 0
|
||||
return r != 0
|
||||
}
|
||||
|
||||
// LastInsertRowID returns the rowid of the most recent successful INSERT
|
||||
@@ -212,7 +212,7 @@ func (c *Conn) GetAutocommit() bool {
|
||||
// https://www.sqlite.org/c3ref/last_insert_rowid.html
|
||||
func (c *Conn) LastInsertRowID() int64 {
|
||||
r := c.call(c.api.lastRowid, uint64(c.handle))
|
||||
return int64(r[0])
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// Changes returns the number of rows modified, inserted or deleted
|
||||
@@ -222,7 +222,7 @@ func (c *Conn) LastInsertRowID() int64 {
|
||||
// https://www.sqlite.org/c3ref/changes.html
|
||||
func (c *Conn) Changes() int64 {
|
||||
r := c.call(c.api.changes, uint64(c.handle))
|
||||
return int64(r[0])
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// SetInterrupt interrupts a long-running query when a context is done.
|
||||
@@ -278,7 +278,8 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
break
|
||||
|
||||
case <-ctx.Done(): // Done was closed.
|
||||
buf := util.View(c.mod, c.handle+c.api.interrupt, 4)
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
// Wait for the next call to SetInterrupt.
|
||||
<-waiter
|
||||
@@ -294,7 +295,8 @@ func (c *Conn) checkInterrupt() bool {
|
||||
if c.interrupt == nil || c.interrupt.Err() == nil {
|
||||
return false
|
||||
}
|
||||
buf := util.View(c.mod, c.handle+c.api.interrupt, 4)
|
||||
const isInterruptedOffset = 288
|
||||
buf := util.View(c.mod, c.handle+isInterruptedOffset, 4)
|
||||
(*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1)
|
||||
return true
|
||||
}
|
||||
@@ -317,21 +319,27 @@ func (c *Conn) Pragma(str string) ([]string, error) {
|
||||
}
|
||||
|
||||
func (c *Conn) error(rc uint64, sql ...string) error {
|
||||
return c.module.error(rc, c.handle, sql...)
|
||||
return c.sqlite.error(rc, c.handle, sql...)
|
||||
}
|
||||
|
||||
// DriverConn is implemented by the SQLite [database/sql] driver connection.
|
||||
//
|
||||
// It can be used to access advanced SQLite features like
|
||||
// [savepoints] and [incremental BLOB I/O].
|
||||
// [savepoints], [online backup] and [incremental BLOB I/O].
|
||||
//
|
||||
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
|
||||
// [online backup]: https://www.sqlite.org/backup.html
|
||||
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
|
||||
type DriverConn interface {
|
||||
driver.Conn
|
||||
driver.ConnBeginTx
|
||||
driver.ExecerContext
|
||||
driver.ConnPrepareContext
|
||||
|
||||
SetInterrupt(ctx context.Context) (old context.Context)
|
||||
|
||||
Savepoint() Savepoint
|
||||
Backup(srcDB, dstURI string) error
|
||||
Restore(dstDB, srcURI string) error
|
||||
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
|
||||
}
|
||||
|
||||
55
const.go
55
const.go
@@ -137,34 +137,23 @@ const (
|
||||
AUTH_USER ExtendedErrorCode = xErrorCode(AUTH) | (1 << 8)
|
||||
)
|
||||
|
||||
// OpenFlag is a flag for a file open operation.
|
||||
// OpenFlag is a flag for the [OpenFlags] function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
|
||||
type OpenFlag uint32
|
||||
|
||||
const (
|
||||
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
|
||||
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
|
||||
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
|
||||
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
|
||||
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
|
||||
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
|
||||
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
|
||||
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
|
||||
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
|
||||
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
|
||||
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
|
||||
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
|
||||
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_EXRESCODE OpenFlag = 0x02000000 /* Extended result codes */
|
||||
)
|
||||
|
||||
// PrepareFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
@@ -178,6 +167,18 @@ const (
|
||||
PREPARE_NO_VTAB PrepareFlag = 0x04
|
||||
)
|
||||
|
||||
// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_deterministic.html
|
||||
type FunctionFlag uint32
|
||||
|
||||
const (
|
||||
DETERMINISTIC FunctionFlag = 0x000000800
|
||||
DIRECTONLY FunctionFlag = 0x000080000
|
||||
SUBTYPE FunctionFlag = 0x000100000
|
||||
INNOCUOUS FunctionFlag = 0x000200000
|
||||
)
|
||||
|
||||
// Datatype is a fundamental datatype of SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_blob.html
|
||||
@@ -193,18 +194,18 @@ const (
|
||||
|
||||
// String implements the [fmt.Stringer] interface.
|
||||
func (t Datatype) String() string {
|
||||
const name = "INTEGERFLOATTEXTBLOBNULL"
|
||||
const name = "INTEGERFLOATEXTBLOBNULL"
|
||||
switch t {
|
||||
case INTEGER:
|
||||
return name[0:7]
|
||||
case FLOAT:
|
||||
return name[7:12]
|
||||
case TEXT:
|
||||
return name[12:16]
|
||||
return name[11:15]
|
||||
case BLOB:
|
||||
return name[16:20]
|
||||
return name[15:19]
|
||||
case NULL:
|
||||
return name[20:24]
|
||||
return name[19:23]
|
||||
}
|
||||
return strconv.FormatUint(uint64(t), 10)
|
||||
}
|
||||
|
||||
174
context.go
Normal file
174
context.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Context is the context in which an SQL function executes.
|
||||
// An SQLite [Context] is in no way related to a Go [context.Context].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/context.html
|
||||
type Context struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// SetAuxData saves metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) SetAuxData(n int, data any) {
|
||||
ptr := util.AddHandle(c.ctx, data)
|
||||
c.call(c.api.setAuxData, uint64(c.handle), uint64(n), uint64(ptr))
|
||||
}
|
||||
|
||||
// GetAuxData returns metadata for argument n of the function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/get_auxdata.html
|
||||
func (c Context) GetAuxData(n int) any {
|
||||
ptr := uint32(c.call(c.api.getAuxData, uint64(c.handle), uint64(n)))
|
||||
return util.GetHandle(c.ctx, ptr)
|
||||
}
|
||||
|
||||
// ResultBool sets the result of the function to a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are stored as integers 0 (false) and 1 (true).
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBool(value bool) {
|
||||
var i int64
|
||||
if value {
|
||||
i = 1
|
||||
}
|
||||
c.ResultInt64(i)
|
||||
}
|
||||
|
||||
// ResultInt sets the result of the function to an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt(value int) {
|
||||
c.ResultInt64(int64(value))
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of the function to an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultInt64(value int64) {
|
||||
c.call(c.api.resultInteger,
|
||||
uint64(c.handle), uint64(value))
|
||||
}
|
||||
|
||||
// ResultFloat sets the result of the function to a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultFloat(value float64) {
|
||||
c.call(c.api.resultFloat,
|
||||
uint64(c.handle), math.Float64bits(value))
|
||||
}
|
||||
|
||||
// ResultText sets the result of the function to a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultText(value string) {
|
||||
ptr := c.newString(value)
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of the function to a []byte.
|
||||
// Returning a nil slice is the same as calling [Context.ResultNull].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultBlob(value []byte) {
|
||||
ptr := c.newBytes(value)
|
||||
c.call(c.api.resultBlob,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(value)),
|
||||
uint64(c.api.destructor))
|
||||
}
|
||||
|
||||
// BindZeroBlob sets the result of the function to a zero-filled, length n BLOB.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultZeroBlob(n int64) {
|
||||
c.call(c.api.resultZeroBlob,
|
||||
uint64(c.handle), uint64(n))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of the function to NULL.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultNull() {
|
||||
c.call(c.api.resultNull,
|
||||
uint64(c.handle))
|
||||
}
|
||||
|
||||
// ResultTime sets the result of the function to a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultTime(value time.Time, format TimeFormat) {
|
||||
if format == TimeFormatDefault {
|
||||
c.resultRFC3339Nano(value)
|
||||
return
|
||||
}
|
||||
switch v := format.Encode(value).(type) {
|
||||
case string:
|
||||
c.ResultText(v)
|
||||
case int64:
|
||||
c.ResultInt64(v)
|
||||
case float64:
|
||||
c.ResultFloat(v)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (c Context) resultRFC3339Nano(value time.Time) {
|
||||
const maxlen = uint64(len(time.RFC3339Nano))
|
||||
|
||||
ptr := c.new(maxlen)
|
||||
buf := util.View(c.mod, ptr, maxlen)
|
||||
buf = value.AppendFormat(buf[:0], time.RFC3339Nano)
|
||||
|
||||
c.call(c.api.resultText,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(buf)),
|
||||
uint64(c.api.destructor), _UTF8)
|
||||
}
|
||||
|
||||
// ResultError sets the result of the function an error.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/result_blob.html
|
||||
func (c Context) ResultError(err error) {
|
||||
if errors.Is(err, NOMEM) {
|
||||
c.call(c.api.resultErrorMem, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, TOOBIG) {
|
||||
c.call(c.api.resultErrorBig, uint64(c.handle))
|
||||
return
|
||||
}
|
||||
|
||||
str := err.Error()
|
||||
ptr := c.newString(str)
|
||||
c.call(c.api.resultError,
|
||||
uint64(c.handle), uint64(ptr), uint64(len(str)))
|
||||
c.free(ptr)
|
||||
|
||||
var code uint64
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
switch {
|
||||
case errors.As(err, &xcode):
|
||||
code = uint64(xcode)
|
||||
case errors.As(err, &ecode):
|
||||
code = uint64(ecode)
|
||||
}
|
||||
if code != 0 {
|
||||
c.call(c.api.resultErrorCode,
|
||||
uint64(c.handle), code)
|
||||
}
|
||||
}
|
||||
219
driver/driver.go
219
driver/driver.go
@@ -14,10 +14,12 @@
|
||||
//
|
||||
// [PRAGMA] statements can be specified using "_pragma":
|
||||
//
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
|
||||
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
|
||||
//
|
||||
// If no PRAGMAs are specifed, a busy timeout of 1 minute
|
||||
// and normal locking mode are used.
|
||||
// If no PRAGMAs are specified, a busy timeout of 1 minute is set.
|
||||
//
|
||||
// Order matters:
|
||||
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
|
||||
//
|
||||
// [URI]: https://www.sqlite.org/uri.html
|
||||
// [PRAGMA]: https://www.sqlite.org/pragma.html
|
||||
@@ -46,12 +48,13 @@ type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.conn, err = sqlite3.Open(name)
|
||||
c.Conn, err = sqlite3.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pragmas []string
|
||||
var pragmas bool
|
||||
c.txBegin = "BEGIN"
|
||||
if strings.HasPrefix(name, "file:") {
|
||||
if _, after, ok := strings.Cut(name, "?"); ok {
|
||||
query, _ := url.ParseQuery(after)
|
||||
@@ -66,21 +69,18 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", s)
|
||||
}
|
||||
|
||||
pragmas = query["_pragma"]
|
||||
pragmas = len(query["_pragma"]) > 0
|
||||
}
|
||||
}
|
||||
if len(pragmas) == 0 {
|
||||
err := c.conn.Exec(`
|
||||
PRAGMA locking_mode=normal;
|
||||
PRAGMA busy_timeout=60000;
|
||||
`)
|
||||
if !pragmas {
|
||||
err := c.Conn.Exec(`PRAGMA busy_timeout=60000`)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
c.reusable = true
|
||||
} else {
|
||||
s, _, err := c.conn.Prepare(`
|
||||
s, _, err := c.Conn.Prepare(`
|
||||
SELECT * FROM
|
||||
PRAGMA_locking_mode,
|
||||
PRAGMA_query_only;
|
||||
@@ -99,11 +99,11 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
conn *sqlite3.Conn
|
||||
*sqlite3.Conn
|
||||
txBegin string
|
||||
txCommit string
|
||||
txRollback string
|
||||
@@ -113,25 +113,21 @@ type conn struct {
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.ExecerContext = conn{}
|
||||
_ driver.ConnBeginTx = conn{}
|
||||
_ driver.Validator = conn{}
|
||||
_ sqlite3.DriverConn = conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
func (c conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c conn) IsValid() bool {
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
txBegin := c.txBegin
|
||||
c.txCommit = `COMMIT`
|
||||
c.txRollback = `ROLLBACK`
|
||||
@@ -155,33 +151,43 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro
|
||||
break
|
||||
}
|
||||
|
||||
err := c.conn.Exec(txBegin)
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
err := c.Conn.Exec(txBegin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c conn) Commit() error {
|
||||
err := c.conn.Exec(c.txCommit)
|
||||
if err != nil && !c.conn.GetAutocommit() {
|
||||
func (c *conn) Commit() error {
|
||||
err := c.Conn.Exec(c.txCommit)
|
||||
if err != nil && !c.GetAutocommit() {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c conn) Rollback() error {
|
||||
return c.conn.Exec(c.txRollback)
|
||||
func (c *conn) Rollback() error {
|
||||
return c.Conn.Exec(c.txRollback)
|
||||
}
|
||||
|
||||
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
s, tail, err := c.conn.Prepare(query)
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
s, tail, err := c.Conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tail != "" {
|
||||
// Check if the tail contains any SQL.
|
||||
st, _, err := c.conn.Prepare(tail)
|
||||
st, _, err := c.Conn.Prepare(tail)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
@@ -192,61 +198,46 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, util.TailErr
|
||||
}
|
||||
}
|
||||
return stmt{s, c.conn}, nil
|
||||
return &stmt{s, c.Conn}, nil
|
||||
}
|
||||
|
||||
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
|
||||
return c.Prepare(query)
|
||||
}
|
||||
|
||||
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if len(args) != 0 {
|
||||
// Slow path.
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
old := c.conn.SetInterrupt(ctx)
|
||||
defer c.conn.SetInterrupt(old)
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
err := c.conn.Exec(query)
|
||||
err := c.Conn.Exec(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
c.conn.LastInsertRowID(),
|
||||
c.conn.Changes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c conn) Savepoint() sqlite3.Savepoint {
|
||||
return c.conn.Savepoint()
|
||||
}
|
||||
|
||||
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
|
||||
return c.conn.OpenBlob(db, table, column, row, write)
|
||||
return newResult(c.Conn), nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
_ driver.NamedValueChecker = stmt{}
|
||||
_ driver.StmtExecContext = &stmt{}
|
||||
_ driver.StmtQueryContext = &stmt{}
|
||||
_ driver.NamedValueChecker = &stmt{}
|
||||
)
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
func (s *stmt) Close() error {
|
||||
return s.Stmt.Close()
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
n := s.stmt.BindCount()
|
||||
func (s *stmt) NumInput() int {
|
||||
n := s.Stmt.BindCount()
|
||||
for i := 1; i <= n; i++ {
|
||||
if s.stmt.BindName(i) != "" {
|
||||
if s.Stmt.BindName(i) != "" {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
@@ -254,16 +245,16 @@ func (s stmt) NumInput() int {
|
||||
}
|
||||
|
||||
// Deprecated: use ExecContext instead.
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
// Deprecated: use QueryContext instead.
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
// Use QueryContext to setup bindings.
|
||||
// No need to close rows: that simply resets the statement, exec does the same.
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
@@ -271,19 +262,16 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.stmt.Exec()
|
||||
err = s.Stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result{
|
||||
int64(s.conn.LastInsertRowID()),
|
||||
int64(s.conn.Changes()),
|
||||
}, nil
|
||||
return newResult(s.Conn), nil
|
||||
}
|
||||
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.stmt.ClearBindings()
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.Stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -295,7 +283,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
ids = append(ids, arg.Ordinal)
|
||||
} else {
|
||||
for _, prefix := range []string{":", "@", "$"} {
|
||||
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
@@ -304,23 +292,23 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
for _, id := range ids {
|
||||
switch a := arg.Value.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(id, a)
|
||||
err = s.Stmt.BindBool(id, a)
|
||||
case int:
|
||||
err = s.stmt.BindInt(id, a)
|
||||
err = s.Stmt.BindInt(id, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(id, a)
|
||||
err = s.Stmt.BindInt64(id, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(id, a)
|
||||
err = s.Stmt.BindFloat(id, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(id, a)
|
||||
err = s.Stmt.BindText(id, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(id, a)
|
||||
err = s.Stmt.BindBlob(id, a)
|
||||
case sqlite3.ZeroBlob:
|
||||
err = s.stmt.BindZeroBlob(id, int64(a))
|
||||
err = s.Stmt.BindZeroBlob(id, int64(a))
|
||||
case time.Time:
|
||||
err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
case nil:
|
||||
err = s.stmt.BindNull(id)
|
||||
err = s.Stmt.BindNull(id)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
@@ -330,10 +318,10 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
}
|
||||
}
|
||||
|
||||
return rows{ctx, s.stmt, s.conn}, nil
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
}
|
||||
|
||||
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, nil:
|
||||
@@ -343,6 +331,17 @@ func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
}
|
||||
}
|
||||
|
||||
func newResult(c *sqlite3.Conn) driver.Result {
|
||||
rows := c.Changes()
|
||||
if rows != 0 {
|
||||
id := c.LastInsertRowID()
|
||||
if id != 0 {
|
||||
return result{id, rows}
|
||||
}
|
||||
}
|
||||
return resultRowsAffected(rows)
|
||||
}
|
||||
|
||||
type result struct{ lastInsertId, rowsAffected int64 }
|
||||
|
||||
func (r result) LastInsertId() (int64, error) {
|
||||
@@ -353,46 +352,56 @@ func (r result) RowsAffected() (int64, error) {
|
||||
return r.rowsAffected, nil
|
||||
}
|
||||
|
||||
type resultRowsAffected int64
|
||||
|
||||
func (r resultRowsAffected) LastInsertId() (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r resultRowsAffected) RowsAffected() (int64, error) {
|
||||
return int64(r), nil
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
ctx context.Context
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
func (r rows) Close() error {
|
||||
return r.stmt.Reset()
|
||||
func (r *rows) Close() error {
|
||||
return r.Stmt.Reset()
|
||||
}
|
||||
|
||||
func (r rows) Columns() []string {
|
||||
count := r.stmt.ColumnCount()
|
||||
func (r *rows) Columns() []string {
|
||||
count := r.Stmt.ColumnCount()
|
||||
columns := make([]string, count)
|
||||
for i := range columns {
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
columns[i] = r.Stmt.ColumnName(i)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (r rows) Next(dest []driver.Value) error {
|
||||
old := r.conn.SetInterrupt(r.ctx)
|
||||
defer r.conn.SetInterrupt(old)
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
old := r.Conn.SetInterrupt(r.ctx)
|
||||
defer r.Conn.SetInterrupt(old)
|
||||
|
||||
if !r.stmt.Step() {
|
||||
if err := r.stmt.Err(); err != nil {
|
||||
if !r.Stmt.Step() {
|
||||
if err := r.Stmt.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch r.stmt.ColumnType(i) {
|
||||
switch r.Stmt.ColumnType(i) {
|
||||
case sqlite3.INTEGER:
|
||||
dest[i] = r.stmt.ColumnInt64(i)
|
||||
dest[i] = r.Stmt.ColumnInt64(i)
|
||||
case sqlite3.FLOAT:
|
||||
dest[i] = r.stmt.ColumnFloat(i)
|
||||
dest[i] = r.Stmt.ColumnFloat(i)
|
||||
case sqlite3.BLOB:
|
||||
dest[i] = r.stmt.ColumnRawBlob(i)
|
||||
dest[i] = r.Stmt.ColumnRawBlob(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = stringOrTime(r.stmt.ColumnRawText(i))
|
||||
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
@@ -404,5 +413,5 @@ func (r rows) Next(dest []driver.Value) error {
|
||||
}
|
||||
}
|
||||
|
||||
return r.stmt.Err()
|
||||
return r.Stmt.Err()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Embeddable WASM build of SQLite
|
||||
|
||||
This folder includes an embeddable WASM build of SQLite 3.41.2 for use with
|
||||
This folder includes an embeddable WASM build of SQLite 3.42.0 for use with
|
||||
[`github.com/ncruces/go-sqlite3`](https://pkg.go.dev/github.com/ncruces/go-sqlite3).
|
||||
|
||||
The following optional features are compiled in:
|
||||
@@ -17,7 +17,8 @@ The following optional features are compiled in:
|
||||
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
|
||||
- [time](../sqlite3/time.c)
|
||||
|
||||
See the [configuration options](../sqlite3/sqlite_cfg.h).
|
||||
See the [configuration options](../sqlite3/sqlite_cfg.h),
|
||||
and [patches](../sqlite3) applied.
|
||||
|
||||
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
|
||||
and [`binaryen`](https://github.com/WebAssembly/binaryen).
|
||||
@@ -4,16 +4,17 @@ set -euo pipefail
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
ROOT=../
|
||||
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
|
||||
BINARYEN="$ROOT/tools/binaryen-version_114/bin"
|
||||
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \
|
||||
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \
|
||||
-I"$ROOT/sqlite3" \
|
||||
-mexec-model=reactor \
|
||||
-mmutable-globals \
|
||||
-msimd128 -mmutable-globals \
|
||||
-mbulk-memory -mreference-types \
|
||||
-mnontrapping-fptoint -msign-ext \
|
||||
-fno-stack-protector -fno-stack-clash-protection \
|
||||
-Wl,--initial-memory=327680 \
|
||||
-Wl,--stack-first \
|
||||
-Wl,--import-undefined \
|
||||
@@ -22,7 +23,8 @@ WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
|
||||
|
||||
trap 'rm -f sqlite3.tmp' EXIT
|
||||
"$BINARYEN/wasm-ctor-eval" -g -c _initialize sqlite3.wasm -o sqlite3.tmp
|
||||
"$BINARYEN/wasm-opt" -g -O2 sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-multivalue --enable-mutable-globals \
|
||||
"$BINARYEN/wasm-opt" -g --strip -c -O3 \
|
||||
sqlite3.tmp -o sqlite3.wasm \
|
||||
--enable-simd --enable-mutable-globals --enable-multivalue \
|
||||
--enable-bulk-memory --enable-reference-types \
|
||||
--enable-nontrapping-float-to-int --enable-sign-ext
|
||||
@@ -33,16 +33,42 @@ sqlite3_column_blob
|
||||
sqlite3_column_bytes
|
||||
sqlite3_blob_open
|
||||
sqlite3_blob_close
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_blob_bytes
|
||||
sqlite3_blob_read
|
||||
sqlite3_blob_write
|
||||
sqlite3_blob_reopen
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_last_insert_rowid
|
||||
sqlite3_changes64
|
||||
sqlite3_backup_init
|
||||
sqlite3_backup_step
|
||||
sqlite3_backup_finish
|
||||
sqlite3_backup_remaining
|
||||
sqlite3_backup_pagecount
|
||||
sqlite3_interrupt_offset
|
||||
sqlite3_uri_parameter
|
||||
sqlite3_uri_key
|
||||
sqlite3_changes64
|
||||
sqlite3_last_insert_rowid
|
||||
sqlite3_get_autocommit
|
||||
sqlite3_anycollseq_init
|
||||
sqlite3_create_collation_go
|
||||
sqlite3_create_function_go
|
||||
sqlite3_create_aggregate_function_go
|
||||
sqlite3_create_window_function_go
|
||||
sqlite3_aggregate_context
|
||||
sqlite3_user_data
|
||||
sqlite3_set_auxdata_go
|
||||
sqlite3_get_auxdata
|
||||
sqlite3_value_type
|
||||
sqlite3_value_int64
|
||||
sqlite3_value_double
|
||||
sqlite3_value_text
|
||||
sqlite3_value_blob
|
||||
sqlite3_value_bytes
|
||||
sqlite3_result_null
|
||||
sqlite3_result_int64
|
||||
sqlite3_result_double
|
||||
sqlite3_result_text64
|
||||
sqlite3_result_blob64
|
||||
sqlite3_result_zeroblob64
|
||||
sqlite3_result_error
|
||||
sqlite3_result_error_code
|
||||
sqlite3_result_error_nomem
|
||||
sqlite3_result_error_toobig
|
||||
Binary file not shown.
103
error.go
103
error.go
@@ -3,6 +3,8 @@ package sqlite3
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Error wraps an SQLite Error Code.
|
||||
@@ -66,6 +68,19 @@ func (e *Error) Is(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode] or [ExtendedErrorCode].
|
||||
func (e *Error) As(err any) bool {
|
||||
switch c := err.(type) {
|
||||
case *ErrorCode:
|
||||
*c = e.Code()
|
||||
return true
|
||||
case *ExtendedErrorCode:
|
||||
*c = e.ExtendedCode()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e *Error) Temporary() bool {
|
||||
return e.Code() == BUSY
|
||||
@@ -83,72 +98,7 @@ func (e *Error) SQL() string {
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ErrorCode) Error() string {
|
||||
switch e {
|
||||
case _OK:
|
||||
return "sqlite3: not an error"
|
||||
case _ROW:
|
||||
return "sqlite3: another row available"
|
||||
case _DONE:
|
||||
return "sqlite3: no more rows available"
|
||||
|
||||
case ERROR:
|
||||
return "sqlite3: SQL logic error"
|
||||
case INTERNAL:
|
||||
break
|
||||
case PERM:
|
||||
return "sqlite3: access permission denied"
|
||||
case ABORT:
|
||||
return "sqlite3: query aborted"
|
||||
case BUSY:
|
||||
return "sqlite3: database is locked"
|
||||
case LOCKED:
|
||||
return "sqlite3: database table is locked"
|
||||
case NOMEM:
|
||||
return "sqlite3: out of memory"
|
||||
case READONLY:
|
||||
return "sqlite3: attempt to write a readonly database"
|
||||
case INTERRUPT:
|
||||
return "sqlite3: interrupted"
|
||||
case IOERR:
|
||||
return "sqlite3: disk I/O error"
|
||||
case CORRUPT:
|
||||
return "sqlite3: database disk image is malformed"
|
||||
case NOTFOUND:
|
||||
return "sqlite3: unknown operation"
|
||||
case FULL:
|
||||
return "sqlite3: database or disk is full"
|
||||
case CANTOPEN:
|
||||
return "sqlite3: unable to open database file"
|
||||
case PROTOCOL:
|
||||
return "sqlite3: locking protocol"
|
||||
case FORMAT:
|
||||
break
|
||||
case SCHEMA:
|
||||
return "sqlite3: database schema has changed"
|
||||
case TOOBIG:
|
||||
return "sqlite3: string or blob too big"
|
||||
case CONSTRAINT:
|
||||
return "sqlite3: constraint failed"
|
||||
case MISMATCH:
|
||||
return "sqlite3: datatype mismatch"
|
||||
case MISUSE:
|
||||
return "sqlite3: bad parameter or other API misuse"
|
||||
case NOLFS:
|
||||
break
|
||||
case AUTH:
|
||||
return "sqlite3: authorization denied"
|
||||
case EMPTY:
|
||||
break
|
||||
case RANGE:
|
||||
return "sqlite3: column index out of range"
|
||||
case NOTADB:
|
||||
return "sqlite3: file is not a database"
|
||||
case NOTICE:
|
||||
return "sqlite3: notification message"
|
||||
case WARNING:
|
||||
return "sqlite3: warning message"
|
||||
}
|
||||
return "sqlite3: unknown error"
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
@@ -158,17 +108,7 @@ func (e ErrorCode) Temporary() bool {
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ExtendedErrorCode) Error() string {
|
||||
switch x := ErrorCode(e); {
|
||||
case e == ABORT_ROLLBACK:
|
||||
return "sqlite3: abort due to ROLLBACK"
|
||||
case x < _ROW:
|
||||
return x.Error()
|
||||
case e == _ROW:
|
||||
return "sqlite3: another row available"
|
||||
case e == _DONE:
|
||||
return "sqlite3: no more rows available"
|
||||
}
|
||||
return "sqlite3: unknown error"
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
}
|
||||
|
||||
// Is tests whether this error matches a given [ErrorCode].
|
||||
@@ -177,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool {
|
||||
return ok && c == ErrorCode(e)
|
||||
}
|
||||
|
||||
// As converts this error to an [ErrorCode].
|
||||
func (e ExtendedErrorCode) As(err any) bool {
|
||||
c, ok := err.(*ErrorCode)
|
||||
if ok {
|
||||
*c = ErrorCode(e)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// Temporary returns true for [BUSY] errors.
|
||||
func (e ExtendedErrorCode) Temporary() bool {
|
||||
return ErrorCode(e) == BUSY
|
||||
|
||||
@@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) {
|
||||
func TestError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := Error{code: 0x8080}
|
||||
if rc := err.Code(); rc != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", rc)
|
||||
var ecode ErrorCode
|
||||
var xcode xErrorCode
|
||||
err := &Error{code: 0x8080}
|
||||
if !errors.As(err, &err) {
|
||||
t.Fatal("want true")
|
||||
}
|
||||
if !errors.Is(&err, ErrorCode(0x80)) {
|
||||
if ecode := err.Code(); ecode != 0x80 {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err, ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if rc := err.ExtendedCode(); rc != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", rc)
|
||||
if xcode := err.ExtendedCode(); xcode != 0x8080 {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(&err, ExtendedErrorCode(0x8080)) {
|
||||
if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) {
|
||||
t.Errorf("got %#x, want 0x8080", uint16(xcode))
|
||||
}
|
||||
if !errors.Is(err, xErrorCode(0x8080)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
if s := err.Error(); s != "sqlite3: 32896" {
|
||||
t.Errorf("got %q", s)
|
||||
}
|
||||
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
|
||||
t.Errorf("got %#x, want 0x80", uint8(ecode))
|
||||
}
|
||||
if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) {
|
||||
t.Errorf("want true")
|
||||
}
|
||||
@@ -122,7 +136,7 @@ func Test_ErrorCode_Error(t *testing.T) {
|
||||
for i := 0; i == int(ErrorCode(i)); i++ {
|
||||
want := "sqlite3: "
|
||||
r := db.call(db.api.errstr, uint64(i))
|
||||
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
|
||||
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
|
||||
|
||||
got := ErrorCode(i).Error()
|
||||
if got != want {
|
||||
@@ -144,7 +158,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
|
||||
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
|
||||
want := "sqlite3: "
|
||||
r := db.call(db.api.errstr, uint64(i))
|
||||
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
|
||||
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
|
||||
|
||||
got := ExtendedErrorCode(i).Error()
|
||||
if got != want {
|
||||
|
||||
186
func.go
Normal file
186
func.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
// AnyCollationNeeded registers a fake collating function
|
||||
// for any unknown collating sequence.
|
||||
// The fake collating function works like BINARY.
|
||||
//
|
||||
// This extension can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c *Conn) AnyCollationNeeded() {
|
||||
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
|
||||
}
|
||||
|
||||
// CreateCollation defines a new collating sequence.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_collation.html
|
||||
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createCollation,
|
||||
uint64(c.handle), uint64(namePtr), uint64(funcPtr))
|
||||
if err := c.error(r); err != nil {
|
||||
util.DelHandle(c.ctx, funcPtr)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFunction defines a new scalar SQL function.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
r := c.call(c.api.createFunction,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
||||
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/create_function.html
|
||||
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
|
||||
call := c.api.createAggregate
|
||||
namePtr := c.arena.string(name)
|
||||
funcPtr := util.AddHandle(c.ctx, fn)
|
||||
if _, ok := fn().(WindowFunction); ok {
|
||||
call = c.api.createWindow
|
||||
}
|
||||
r := c.call(call,
|
||||
uint64(c.handle), uint64(namePtr), uint64(nArg),
|
||||
uint64(flag), uint64(funcPtr))
|
||||
return c.error(r)
|
||||
}
|
||||
|
||||
// AggregateFunction is the interface an aggregate function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/appfunc.html
|
||||
type AggregateFunction interface {
|
||||
// Step is invoked to add a row to the current window.
|
||||
// The function arguments, if any, corresponding to the row being added are passed to Step.
|
||||
Step(ctx Context, arg ...Value)
|
||||
|
||||
// Value is invoked to return the current value of the aggregate.
|
||||
Value(ctx Context)
|
||||
}
|
||||
|
||||
// WindowFunction is the interface an aggregate window function should implement.
|
||||
//
|
||||
// https://www.sqlite.org/windowfunctions.html
|
||||
type WindowFunction interface {
|
||||
AggregateFunction
|
||||
|
||||
// Inverse is invoked to remove the oldest presently aggregated result of Step from the current window.
|
||||
// The function arguments, if any, are those passed to Step for the row being removed.
|
||||
Inverse(ctx Context, arg ...Value)
|
||||
}
|
||||
|
||||
func exportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.ExportFuncVI(env, "go_destroy", callbackDestroy)
|
||||
util.ExportFuncIIIIII(env, "go_compare", callbackCompare)
|
||||
util.ExportFuncVIII(env, "go_func", callbackFunc)
|
||||
util.ExportFuncVIII(env, "go_step", callbackStep)
|
||||
util.ExportFuncVI(env, "go_final", callbackFinal)
|
||||
util.ExportFuncVI(env, "go_value", callbackValue)
|
||||
util.ExportFuncVIII(env, "go_inverse", callbackInverse)
|
||||
return env
|
||||
}
|
||||
|
||||
func callbackDestroy(ctx context.Context, mod api.Module, pApp uint32) {
|
||||
util.DelHandle(ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackCompare(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nKey2, pKey2 uint32) uint32 {
|
||||
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
|
||||
return uint32(fn(util.View(mod, pKey1, uint64(nKey1)), util.View(mod, pKey2, uint64(nKey2))))
|
||||
}
|
||||
|
||||
func callbackFunc(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackHandle(sqlite, pCtx).(func(ctx Context, arg ...Value))
|
||||
fn(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackStep(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Step(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackFinal(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
var handle uint32
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, &handle).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
if err := util.DelHandle(ctx, handle); err != nil {
|
||||
Context{sqlite, pCtx}.ResultError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackValue(ctx context.Context, mod api.Module, pCtx uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(AggregateFunction)
|
||||
fn.Value(Context{sqlite, pCtx})
|
||||
}
|
||||
|
||||
func callbackInverse(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
|
||||
sqlite := ctx.Value(sqliteKey{}).(*sqlite)
|
||||
fn := callbackAggregate(sqlite, pCtx, nil).(WindowFunction)
|
||||
fn.Inverse(Context{sqlite, pCtx}, callbackArgs(sqlite, nArg, pArg)...)
|
||||
}
|
||||
|
||||
func callbackHandle(sqlite *sqlite, pCtx uint32) any {
|
||||
pApp := uint32(sqlite.call(sqlite.api.userData, uint64(pCtx)))
|
||||
return util.GetHandle(sqlite.ctx, pApp)
|
||||
}
|
||||
|
||||
func callbackAggregate(sqlite *sqlite, pCtx uint32, close *uint32) any {
|
||||
// On close, we're getting rid of the handle.
|
||||
// Don't allocate space to store it.
|
||||
var size uint64
|
||||
if close == nil {
|
||||
size = ptrlen
|
||||
}
|
||||
ptr := uint32(sqlite.call(sqlite.api.aggregateCtx, uint64(pCtx), size))
|
||||
|
||||
// Try loading the handle, if we already have one, or want a new one.
|
||||
if ptr != 0 || size != 0 {
|
||||
if handle := util.ReadUint32(sqlite.mod, ptr); handle != 0 {
|
||||
fn := util.GetHandle(sqlite.ctx, handle)
|
||||
if close != nil {
|
||||
*close = handle
|
||||
}
|
||||
if fn != nil {
|
||||
return fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new aggregate and store the handle.
|
||||
fn := callbackHandle(sqlite, pCtx).(func() AggregateFunction)()
|
||||
if ptr != 0 {
|
||||
util.WriteUint32(sqlite.mod, ptr, util.AddHandle(sqlite.ctx, fn))
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func callbackArgs(sqlite *sqlite, nArg, pArg uint32) []Value {
|
||||
args := make([]Value, nArg)
|
||||
for i := range args {
|
||||
args[i] = Value{
|
||||
sqlite: sqlite,
|
||||
handle: util.ReadUint32(sqlite.mod, pArg+ptrlen*uint32(i)),
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
181
func_test.go
Normal file
181
func_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
|
||||
"golang.org/x/text/collate"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateCollation() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateCollation("french", collate.New(language.French).Compare)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE french`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// cote
|
||||
// coté
|
||||
// côte
|
||||
// côté
|
||||
// cotée
|
||||
// coter
|
||||
}
|
||||
|
||||
func ExampleConn_CreateFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("upper", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT upper(word) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// COTE
|
||||
// COTÉ
|
||||
// CÔTE
|
||||
// CÔTÉ
|
||||
// COTÉE
|
||||
// COTER
|
||||
}
|
||||
|
||||
func ExampleContext_SetAuxData() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateFunction("regexp", 2, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
}
|
||||
ctx.SetAuxData(0, r)
|
||||
re = r
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawText()))
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT word FROM words WHERE word REGEXP '^\p{L}+e$'`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnText(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Unordered output:
|
||||
// cote
|
||||
// côte
|
||||
// cotée
|
||||
}
|
||||
96
func_win_test.go
Normal file
96
func_win_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"unicode"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func ExampleConn_CreateWindowFunction() {
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS words (word VARCHAR(10))`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT count_ascii(word) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM words`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for stmt.Step() {
|
||||
fmt.Println(stmt.ColumnInt(0))
|
||||
}
|
||||
if err := stmt.Err(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// Output:
|
||||
// 1
|
||||
// 2
|
||||
// 2
|
||||
// 1
|
||||
// 0
|
||||
// 0
|
||||
}
|
||||
|
||||
type countASCII struct{ result int }
|
||||
|
||||
func newASCIICounter() sqlite3.AggregateFunction {
|
||||
return &countASCII{}
|
||||
}
|
||||
|
||||
func (f *countASCII) Value(ctx sqlite3.Context) {
|
||||
ctx.ResultInt(f.result)
|
||||
}
|
||||
|
||||
func (f *countASCII) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result++
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if f.isASCII(arg[0]) {
|
||||
f.result--
|
||||
}
|
||||
}
|
||||
|
||||
func (f *countASCII) isASCII(arg sqlite3.Value) bool {
|
||||
if arg.Type() != sqlite3.TEXT {
|
||||
return false
|
||||
}
|
||||
for _, c := range arg.RawText() {
|
||||
if c > unicode.MaxASCII {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
8
go.mod
8
go.mod
@@ -4,9 +4,11 @@ go 1.19
|
||||
|
||||
require (
|
||||
github.com/ncruces/julianday v0.1.5
|
||||
github.com/tetratelabs/wazero v1.0.3
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/sys v0.7.0
|
||||
github.com/psanford/httpreadat v0.1.0
|
||||
github.com/tetratelabs/wazero v1.2.1
|
||||
golang.org/x/sync v0.3.0
|
||||
golang.org/x/sys v0.10.0
|
||||
golang.org/x/text v0.10.0
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
|
||||
16
go.sum
16
go.sum
@@ -1,8 +1,12 @@
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.0.3 h1:IWmaxc/5vKg71DE+c0SLjjLFAA3u3tD/Zegpgif2Wpo=
|
||||
github.com/tetratelabs/wazero v1.0.3/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
|
||||
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
|
||||
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
|
||||
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
|
||||
4
go.work.sum
Normal file
4
go.work.sum
Normal file
@@ -0,0 +1,4 @@
|
||||
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
22
gormlite/LICENSE
Normal file
22
gormlite/LICENSE
Normal file
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Nuno Cruces
|
||||
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
26
gormlite/README.md
Normal file
26
gormlite/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# GORM SQLite Driver
|
||||
|
||||
[](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/gormlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
|
||||
### Foreign-key constraint activation
|
||||
|
||||
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
|
||||
```go
|
||||
db, err := gorm.Open(gormlite.Open(
|
||||
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=foreign_keys(1)"),
|
||||
&gorm.Config{})
|
||||
```
|
||||
231
gormlite/ddlmod.go
Normal file
231
gormlite/ddlmod.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteSeparator = "`|\"|'|\t"
|
||||
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
|
||||
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
|
||||
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
|
||||
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
|
||||
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
|
||||
)
|
||||
|
||||
func getAllColumns(s string) []string {
|
||||
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
|
||||
columns := make([]string, 0, len(allMatches))
|
||||
for _, matches := range allMatches {
|
||||
if len(matches) > 1 {
|
||||
columns = append(columns, matches[1])
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
type ddl struct {
|
||||
head string
|
||||
fields []string
|
||||
columns []migrator.ColumnType
|
||||
}
|
||||
|
||||
func parseDDL(strs ...string) (*ddl, error) {
|
||||
var result ddl
|
||||
for _, str := range strs {
|
||||
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
|
||||
var (
|
||||
ddlBody = sections[2]
|
||||
ddlBodyRunes = []rune(ddlBody)
|
||||
bracketLevel int
|
||||
quote rune
|
||||
buf string
|
||||
)
|
||||
ddlBodyRunesLen := len(ddlBodyRunes)
|
||||
|
||||
result.head = sections[1]
|
||||
|
||||
for idx := 0; idx < ddlBodyRunesLen; idx++ {
|
||||
var (
|
||||
next rune = 0
|
||||
c = ddlBodyRunes[idx]
|
||||
)
|
||||
if idx+1 < ddlBodyRunesLen {
|
||||
next = ddlBodyRunes[idx+1]
|
||||
}
|
||||
|
||||
if sc := string(c); separatorRegexp.MatchString(sc) {
|
||||
if c == next {
|
||||
buf += sc // Skip escaped quote
|
||||
idx++
|
||||
} else if quote > 0 {
|
||||
quote = 0
|
||||
} else {
|
||||
quote = c
|
||||
}
|
||||
} else if quote == 0 {
|
||||
if c == '(' {
|
||||
bracketLevel++
|
||||
} else if c == ')' {
|
||||
bracketLevel--
|
||||
} else if bracketLevel == 0 {
|
||||
if c == ',' {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
buf = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bracketLevel < 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
buf += string(c)
|
||||
}
|
||||
|
||||
if bracketLevel != 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
if buf != "" {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
}
|
||||
|
||||
for _, f := range result.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
|
||||
for _, name := range getAllColumns(f) {
|
||||
for idx, column := range result.columns {
|
||||
if column.NameValue.String == name {
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
result.columns[idx] = column
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
|
||||
columnType := migrator.ColumnType{
|
||||
NameValue: sql.NullString{String: matches[1], Valid: true},
|
||||
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
UniqueValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
}
|
||||
|
||||
matchUpper := strings.ToUpper(matches[3])
|
||||
if strings.Contains(matchUpper, " NOT NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
|
||||
} else if strings.Contains(matchUpper, " NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " UNIQUE") {
|
||||
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " PRIMARY") {
|
||||
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
|
||||
if strings.ToLower(defaultMatches[1]) != "null" {
|
||||
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
|
||||
}
|
||||
}
|
||||
|
||||
// data type length
|
||||
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
|
||||
if len(matches) == 1 && len(matches[0]) == 2 {
|
||||
size, _ := strconv.Atoi(matches[0][1])
|
||||
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
|
||||
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
|
||||
}
|
||||
|
||||
result.columns = append(result.columns, columnType)
|
||||
}
|
||||
}
|
||||
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
|
||||
for _, column := range getAllColumns(matches[1]) {
|
||||
for idx, c := range result.columns {
|
||||
if c.NameValue.String == column {
|
||||
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
|
||||
result.columns[idx] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("invalid DDL")
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
}
|
||||
|
||||
func (d *ddl) removeConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
for _, f := range d.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
|
||||
strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
|
||||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
|
||||
continue
|
||||
}
|
||||
|
||||
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
|
||||
match := reg.FindStringSubmatch(f)
|
||||
|
||||
if match != nil {
|
||||
res = append(res, "`"+match[1]+"`")
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
352
gormlite/ddlmod_test.go
Normal file
352
gormlite/ddlmod_test.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestParseDDL(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{"with_fk", []string{
|
||||
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
|
||||
}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
}},
|
||||
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"no brackets", []string{"create table test"}, 0, nil},
|
||||
{"with_special_characters", []string{
|
||||
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
|
||||
}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"table_name_with_dash",
|
||||
[]string{
|
||||
"CREATE TABLE `test-a` (`id` int NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"non-unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
|
||||
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
tests.AssertEqual(t, p.sql[0], ddl.compile())
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
testColumns := []migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
|
||||
},
|
||||
{
|
||||
NameValue: sql.NullString{String: "dark_mode", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
DefaultValueValue: sql.NullString{String: "true", Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{
|
||||
"with_newline",
|
||||
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_newline_2",
|
||||
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_missing_space",
|
||||
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_many_spaces",
|
||||
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
}
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_error(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql string
|
||||
}{
|
||||
{"invalid_cmd", "CREATE TABLE"},
|
||||
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
|
||||
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
_, err := parseDDL(p.sql)
|
||||
if err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
sql string
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "add_new",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
},
|
||||
{
|
||||
name: "update",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
},
|
||||
{
|
||||
name: "add_check",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
{
|
||||
name: "update_check",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
testDDL.addConstraint(p.cName, p.sql)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
success bool
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "fk",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "check",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "none",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "nothing",
|
||||
success: false,
|
||||
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
success := testDDL.removeConstraint(p.cName)
|
||||
|
||||
tests.AssertEqual(t, p.success, success)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumns(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
ddl string
|
||||
columns []string
|
||||
}{
|
||||
{
|
||||
name: "with_fk",
|
||||
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
columns: []string{"`id`", "`text`", "`user_id`"},
|
||||
},
|
||||
{
|
||||
name: "with_check",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
|
||||
},
|
||||
{
|
||||
name: "with_escaped_quote",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_generated_column",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_new_line",
|
||||
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
|
||||
"id" integer PRIMARY KEY AUTOINCREMENT,
|
||||
"created_at" datetime NOT NULL,
|
||||
"updated_at" datetime NOT NULL,
|
||||
"created_by" integer NOT NULL DEFAULT 0,
|
||||
"updated_by" integer NOT NULL DEFAULT 0,
|
||||
"role_id" integer NOT NULL,
|
||||
"menu_id" bigint NOT NULL
|
||||
)`,
|
||||
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL, err := parseDDL(p.ddl)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
cols := testDDL.getColumns()
|
||||
|
||||
tests.AssertEqual(t, p.columns, cols)
|
||||
})
|
||||
}
|
||||
}
|
||||
11
gormlite/download.sh
Executable file
11
gormlite/download.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
|
||||
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"
|
||||
21
gormlite/error_translator.go
Normal file
21
gormlite/error_translator.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
switch {
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
|
||||
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
|
||||
return gorm.ErrDuplicatedKey
|
||||
case
|
||||
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
|
||||
return gorm.ErrForeignKeyViolated
|
||||
}
|
||||
return err
|
||||
}
|
||||
16
gormlite/go.mod
Normal file
16
gormlite/go.mod
Normal file
@@ -0,0 +1,16 @@
|
||||
module github.com/ncruces/go-sqlite3/gormlite
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/ncruces/go-sqlite3 v0.8.1
|
||||
gorm.io/gorm v1.25.2
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/ncruces/julianday v0.1.5 // indirect
|
||||
github.com/tetratelabs/wazero v1.2.1 // indirect
|
||||
golang.org/x/sys v0.9.0 // indirect
|
||||
)
|
||||
14
gormlite/go.sum
Normal file
14
gormlite/go.sum
Normal file
@@ -0,0 +1,14 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/ncruces/go-sqlite3 v0.8.1 h1:e1Y7uHu96xC4fWKsCVWprbTi8vAaQX9R+8kgkxOHWaY=
|
||||
github.com/ncruces/go-sqlite3 v0.8.1/go.mod h1:EhHe1qvG6Zc/8ffYMzre8n//rTRs1YNN5dUD1f1mEGc=
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
|
||||
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
|
||||
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
431
gormlite/migrator.go
Normal file
431
gormlite/migrator.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
var enabled int
|
||||
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
|
||||
if enabled == 1 {
|
||||
m.DB.Exec("PRAGMA foreign_keys = OFF")
|
||||
defer m.DB.Exec("PRAGMA foreign_keys = ON")
|
||||
}
|
||||
|
||||
return fc()
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
|
||||
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
|
||||
|
||||
if createSQL == rawDDL {
|
||||
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
|
||||
}
|
||||
|
||||
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
}
|
||||
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
sqls []string
|
||||
sqlDDL *ddl
|
||||
)
|
||||
|
||||
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlDDL, err = parseDDL(sqls...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = rows.Close()
|
||||
}()
|
||||
|
||||
var rawColumnTypes []*sql.ColumnType
|
||||
rawColumnTypes, err = rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
columnType := migrator.ColumnType{SQLColumnType: c}
|
||||
for _, column := range sqlDDL.columns {
|
||||
if column.NameValue.String == c.Name() {
|
||||
column.SQLColumnType = c
|
||||
columnType = column
|
||||
break
|
||||
}
|
||||
}
|
||||
columnTypes = append(columnTypes, columnType)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, "")
|
||||
|
||||
return createSQL, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
constraintValues []interface{}
|
||||
)
|
||||
|
||||
if constraint != nil {
|
||||
constraintName = constraint.Name
|
||||
constraintSql, constraintValues = buildConstraint(constraint)
|
||||
} else if chk != nil {
|
||||
constraintName = chk.Name
|
||||
constraintSql = "CONSTRAINT ? CHECK (?)"
|
||||
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
} else {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.addConstraint(constraintName, constraintSql)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.removeConstraint(name)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, nil, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
createIndexSQL += " ON ??"
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
if sql != "" {
|
||||
if err := m.DropIndex(value, oldName); err != nil {
|
||||
return err
|
||||
}
|
||||
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
|
||||
}
|
||||
return fmt.Errorf("failed to find index with name %v", oldName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
var createSQL string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
|
||||
|
||||
if m.DB.Error != nil {
|
||||
return "", m.DB.Error
|
||||
}
|
||||
return createSQL, nil
|
||||
}
|
||||
|
||||
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
table := stmt.Table
|
||||
if tablePtr != nil {
|
||||
table = *tablePtr
|
||||
}
|
||||
|
||||
rawDDL, err := m.getRawDDL(table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newTableName := table + "__temp"
|
||||
|
||||
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createSQL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
|
||||
createDDL, err := parseDDL(createSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
columns := createDDL.getColumns()
|
||||
|
||||
return m.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := []string{
|
||||
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
|
||||
fmt.Sprintf("DROP TABLE `%v`", table),
|
||||
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
|
||||
}
|
||||
for _, query := range queries {
|
||||
if err := tx.Exec(query).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
219
gormlite/sqlite.go
Normal file
219
gormlite/sqlite.go
Normal file
@@ -0,0 +1,219 @@
|
||||
// Package gormlite provides a GORM driver for SQLite.
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
conn, err := sql.Open("sqlite3", dialector.DSN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.ConnPool = conn
|
||||
}
|
||||
|
||||
var version string
|
||||
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
|
||||
return err
|
||||
}
|
||||
// https://www.sqlite.org/releaselog/3_35_0.html
|
||||
if compareVersion(version, "3.35.0") >= 0 {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
} else {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"INSERT": func(c clause.Clause, builder clause.Builder) {
|
||||
if insert, ok := c.Expression.(clause.Insert); ok {
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
stmt.WriteString("INSERT ")
|
||||
if insert.Modifier != "" {
|
||||
stmt.WriteString(insert.Modifier)
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
stmt.WriteString("INTO ")
|
||||
if insert.Table.Name == "" {
|
||||
stmt.WriteQuoted(stmt.Table)
|
||||
} else {
|
||||
stmt.WriteQuoted(insert.Table)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
},
|
||||
"LIMIT": func(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
var lmt = -1
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
lmt = *limit.Limit
|
||||
}
|
||||
if lmt >= 0 || limit.Offset > 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(lmt))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
builder.WriteString(" OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
}
|
||||
}
|
||||
},
|
||||
"FOR": func(c clause.Clause, builder clause.Builder) {
|
||||
if _, ok := c.Expression.(clause.Locking); ok {
|
||||
// SQLite3 does not support row-level locking.
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
if field.AutoIncrement {
|
||||
return clause.Expr{SQL: "NULL"}
|
||||
}
|
||||
|
||||
// doesn't work, will raise error
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('`')
|
||||
if strings.Contains(str, ".") {
|
||||
for idx, str := range strings.Split(str, ".") {
|
||||
if idx > 0 {
|
||||
writer.WriteString(".`")
|
||||
}
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
} else {
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement && !field.PrimaryKey {
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
return "integer"
|
||||
}
|
||||
case schema.Float:
|
||||
return "real"
|
||||
case schema.String:
|
||||
return "text"
|
||||
case schema.Time:
|
||||
// Distinguish between schema.Time and tag time
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
return val
|
||||
} else {
|
||||
return "datetime"
|
||||
}
|
||||
case schema.Bytes:
|
||||
return "blob"
|
||||
}
|
||||
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareVersion(version1, version2 string) int {
|
||||
n, m := len(version1), len(version2)
|
||||
i, j := 0, 0
|
||||
for i < n || j < m {
|
||||
x := 0
|
||||
for ; i < n && version1[i] != '.'; i++ {
|
||||
x = x*10 + int(version1[i]-'0')
|
||||
}
|
||||
i++
|
||||
y := 0
|
||||
for ; j < m && version2[j] != '.'; j++ {
|
||||
y = y*10 + int(version2[j]-'0')
|
||||
}
|
||||
j++
|
||||
if x > y {
|
||||
return 1
|
||||
}
|
||||
if x < y {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
64
gormlite/sqlite_test.go
Normal file
64
gormlite/sqlite_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDialector(t *testing.T) {
|
||||
// This is the DSN of the in-memory SQLite database for these tests.
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector *Dialector
|
||||
openSuccess bool
|
||||
query string
|
||||
querySuccess bool
|
||||
}{
|
||||
{
|
||||
description: "Default driver",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
}
|
||||
for rowIndex, row := range rows {
|
||||
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
|
||||
db, err := gorm.Open(row.dialector, &gorm.Config{})
|
||||
if !row.openSuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected Open to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected Open to succeed; got error: %v", err)
|
||||
}
|
||||
if db == nil {
|
||||
t.Errorf("Expected db to be non-nil.")
|
||||
}
|
||||
if row.query != "" {
|
||||
err = db.Exec(row.query).Error
|
||||
if !row.querySuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected query to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected query to succeed; got error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
18
gormlite/test.sh
Executable file
18
gormlite/test.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
rm -rf gorm/ tests/ $TMPDIR/gorm.db
|
||||
git clone --filter=blob:none https://github.com/go-gorm/gorm.git
|
||||
mv gorm/tests tests
|
||||
rm -rf gorm/
|
||||
|
||||
patch -p1 -N < tests.patch
|
||||
|
||||
cd tests
|
||||
go mod tidy && go work use . && go test
|
||||
|
||||
cd ..
|
||||
rm -rf tests/ $TMPDIR/gorm.db
|
||||
go work use -r .
|
||||
42
gormlite/tests.patch
Normal file
42
gormlite/tests.patch
Normal file
@@ -0,0 +1,42 @@
|
||||
diff --git a/tests/.gitignore b/tests/.gitignore
|
||||
--- a/tests/.gitignore
|
||||
+++ b/tests/.gitignore
|
||||
@@ -1 +1 @@
|
||||
-go.sum
|
||||
+*
|
||||
diff --git a/tests/go.mod b/tests/go.mod
|
||||
--- a/tests/go.mod
|
||||
+++ b/tests/go.mod
|
||||
@@ -7,12 +7,12 @@ require (
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jinzhu/now v1.1.5
|
||||
github.com/lib/pq v1.10.8
|
||||
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
|
||||
+ github.com/ncruces/go-sqlite3 v0.7.2
|
||||
+ github.com/ncruces/go-sqlite3/gormlite v0.0.0
|
||||
gorm.io/driver/mysql v1.5.0
|
||||
gorm.io/driver/postgres v1.5.0
|
||||
- gorm.io/driver/sqlite v1.5.0
|
||||
gorm.io/driver/sqlserver v1.5.1
|
||||
- gorm.io/gorm v1.25.1
|
||||
+ gorm.io/gorm v1.25.2
|
||||
)
|
||||
|
||||
-replace gorm.io/gorm => ../
|
||||
+replace github.com/ncruces/go-sqlite3/gormlite => ../
|
||||
diff --git a/tests/tests_test.go b/tests/tests_test.go
|
||||
--- a/tests/tests_test.go
|
||||
+++ b/tests/tests_test.go
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
+ _ "github.com/ncruces/go-sqlite3/embed"
|
||||
+ sqlite "github.com/ncruces/go-sqlite3/gormlite"
|
||||
+
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
- "gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
22
internal/util/bool.go
Normal file
22
internal/util/bool.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package util
|
||||
|
||||
import "strings"
|
||||
|
||||
func ParseBool(s string) (b, ok bool) {
|
||||
if len(s) == 0 {
|
||||
return false, false
|
||||
}
|
||||
if s[0] == '0' {
|
||||
return false, true
|
||||
}
|
||||
if '1' <= s[0] && s[0] <= '9' {
|
||||
return true, true
|
||||
}
|
||||
switch strings.ToLower(s) {
|
||||
case "true", "yes", "on":
|
||||
return true, true
|
||||
case "false", "no", "off":
|
||||
return false, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
28
internal/util/bool_test.go
Normal file
28
internal/util/bool_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
str string
|
||||
val bool
|
||||
ok bool
|
||||
}{
|
||||
{"", false, false},
|
||||
{"0", false, true},
|
||||
{"1", true, true},
|
||||
{"9", true, true},
|
||||
{"T", false, false},
|
||||
{"true", true, true},
|
||||
{"FALSE", false, true},
|
||||
{"false?", false, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.str, func(t *testing.T) {
|
||||
gotVal, gotOK := ParseBool(tt.str)
|
||||
if gotVal != tt.val || gotOK != tt.ok {
|
||||
t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
116
internal/util/const.go
Normal file
116
internal/util/const.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package util
|
||||
|
||||
// https://sqlite.com/matrix/rescode.html
|
||||
const (
|
||||
OK = 0 /* Successful result */
|
||||
|
||||
ERROR = 1 /* Generic error */
|
||||
INTERNAL = 2 /* Internal logic error in SQLite */
|
||||
PERM = 3 /* Access permission denied */
|
||||
ABORT = 4 /* Callback routine requested an abort */
|
||||
BUSY = 5 /* The database file is locked */
|
||||
LOCKED = 6 /* A table in the database is locked */
|
||||
NOMEM = 7 /* A malloc() failed */
|
||||
READONLY = 8 /* Attempt to write a readonly database */
|
||||
INTERRUPT = 9 /* Operation terminated by sqlite3_interrupt() */
|
||||
IOERR = 10 /* Some kind of disk I/O error occurred */
|
||||
CORRUPT = 11 /* The database disk image is malformed */
|
||||
NOTFOUND = 12 /* Unknown opcode in sqlite3_file_control() */
|
||||
FULL = 13 /* Insertion failed because database is full */
|
||||
CANTOPEN = 14 /* Unable to open the database file */
|
||||
PROTOCOL = 15 /* Database lock protocol error */
|
||||
EMPTY = 16 /* Internal use only */
|
||||
SCHEMA = 17 /* The database schema changed */
|
||||
TOOBIG = 18 /* String or BLOB exceeds size limit */
|
||||
CONSTRAINT = 19 /* Abort due to constraint violation */
|
||||
MISMATCH = 20 /* Data type mismatch */
|
||||
MISUSE = 21 /* Library used incorrectly */
|
||||
NOLFS = 22 /* Uses OS features not supported on host */
|
||||
AUTH = 23 /* Authorization denied */
|
||||
FORMAT = 24 /* Not used */
|
||||
RANGE = 25 /* 2nd parameter to sqlite3_bind out of range */
|
||||
NOTADB = 26 /* File opened that is not a database file */
|
||||
NOTICE = 27 /* Notifications from sqlite3_log() */
|
||||
WARNING = 28 /* Warnings from sqlite3_log() */
|
||||
|
||||
ROW = 100 /* sqlite3_step() has another row ready */
|
||||
DONE = 101 /* sqlite3_step() has finished executing */
|
||||
|
||||
ERROR_MISSING_COLLSEQ = ERROR | (1 << 8)
|
||||
ERROR_RETRY = ERROR | (2 << 8)
|
||||
ERROR_SNAPSHOT = ERROR | (3 << 8)
|
||||
IOERR_READ = IOERR | (1 << 8)
|
||||
IOERR_SHORT_READ = IOERR | (2 << 8)
|
||||
IOERR_WRITE = IOERR | (3 << 8)
|
||||
IOERR_FSYNC = IOERR | (4 << 8)
|
||||
IOERR_DIR_FSYNC = IOERR | (5 << 8)
|
||||
IOERR_TRUNCATE = IOERR | (6 << 8)
|
||||
IOERR_FSTAT = IOERR | (7 << 8)
|
||||
IOERR_UNLOCK = IOERR | (8 << 8)
|
||||
IOERR_RDLOCK = IOERR | (9 << 8)
|
||||
IOERR_DELETE = IOERR | (10 << 8)
|
||||
IOERR_BLOCKED = IOERR | (11 << 8)
|
||||
IOERR_NOMEM = IOERR | (12 << 8)
|
||||
IOERR_ACCESS = IOERR | (13 << 8)
|
||||
IOERR_CHECKRESERVEDLOCK = IOERR | (14 << 8)
|
||||
IOERR_LOCK = IOERR | (15 << 8)
|
||||
IOERR_CLOSE = IOERR | (16 << 8)
|
||||
IOERR_DIR_CLOSE = IOERR | (17 << 8)
|
||||
IOERR_SHMOPEN = IOERR | (18 << 8)
|
||||
IOERR_SHMSIZE = IOERR | (19 << 8)
|
||||
IOERR_SHMLOCK = IOERR | (20 << 8)
|
||||
IOERR_SHMMAP = IOERR | (21 << 8)
|
||||
IOERR_SEEK = IOERR | (22 << 8)
|
||||
IOERR_DELETE_NOENT = IOERR | (23 << 8)
|
||||
IOERR_MMAP = IOERR | (24 << 8)
|
||||
IOERR_GETTEMPPATH = IOERR | (25 << 8)
|
||||
IOERR_CONVPATH = IOERR | (26 << 8)
|
||||
IOERR_VNODE = IOERR | (27 << 8)
|
||||
IOERR_AUTH = IOERR | (28 << 8)
|
||||
IOERR_BEGIN_ATOMIC = IOERR | (29 << 8)
|
||||
IOERR_COMMIT_ATOMIC = IOERR | (30 << 8)
|
||||
IOERR_ROLLBACK_ATOMIC = IOERR | (31 << 8)
|
||||
IOERR_DATA = IOERR | (32 << 8)
|
||||
IOERR_CORRUPTFS = IOERR | (33 << 8)
|
||||
LOCKED_SHAREDCACHE = LOCKED | (1 << 8)
|
||||
LOCKED_VTAB = LOCKED | (2 << 8)
|
||||
BUSY_RECOVERY = BUSY | (1 << 8)
|
||||
BUSY_SNAPSHOT = BUSY | (2 << 8)
|
||||
BUSY_TIMEOUT = BUSY | (3 << 8)
|
||||
CANTOPEN_NOTEMPDIR = CANTOPEN | (1 << 8)
|
||||
CANTOPEN_ISDIR = CANTOPEN | (2 << 8)
|
||||
CANTOPEN_FULLPATH = CANTOPEN | (3 << 8)
|
||||
CANTOPEN_CONVPATH = CANTOPEN | (4 << 8)
|
||||
CANTOPEN_DIRTYWAL = CANTOPEN | (5 << 8) /* Not Used */
|
||||
CANTOPEN_SYMLINK = CANTOPEN | (6 << 8)
|
||||
CORRUPT_VTAB = CORRUPT | (1 << 8)
|
||||
CORRUPT_SEQUENCE = CORRUPT | (2 << 8)
|
||||
CORRUPT_INDEX = CORRUPT | (3 << 8)
|
||||
READONLY_RECOVERY = READONLY | (1 << 8)
|
||||
READONLY_CANTLOCK = READONLY | (2 << 8)
|
||||
READONLY_ROLLBACK = READONLY | (3 << 8)
|
||||
READONLY_DBMOVED = READONLY | (4 << 8)
|
||||
READONLY_CANTINIT = READONLY | (5 << 8)
|
||||
READONLY_DIRECTORY = READONLY | (6 << 8)
|
||||
ABORT_ROLLBACK = ABORT | (2 << 8)
|
||||
CONSTRAINT_CHECK = CONSTRAINT | (1 << 8)
|
||||
CONSTRAINT_COMMITHOOK = CONSTRAINT | (2 << 8)
|
||||
CONSTRAINT_FOREIGNKEY = CONSTRAINT | (3 << 8)
|
||||
CONSTRAINT_FUNCTION = CONSTRAINT | (4 << 8)
|
||||
CONSTRAINT_NOTNULL = CONSTRAINT | (5 << 8)
|
||||
CONSTRAINT_PRIMARYKEY = CONSTRAINT | (6 << 8)
|
||||
CONSTRAINT_TRIGGER = CONSTRAINT | (7 << 8)
|
||||
CONSTRAINT_UNIQUE = CONSTRAINT | (8 << 8)
|
||||
CONSTRAINT_VTAB = CONSTRAINT | (9 << 8)
|
||||
CONSTRAINT_ROWID = CONSTRAINT | (10 << 8)
|
||||
CONSTRAINT_PINNED = CONSTRAINT | (11 << 8)
|
||||
CONSTRAINT_DATATYPE = CONSTRAINT | (12 << 8)
|
||||
NOTICE_RECOVER_WAL = NOTICE | (1 << 8)
|
||||
NOTICE_RECOVER_ROLLBACK = NOTICE | (2 << 8)
|
||||
NOTICE_RBU = NOTICE | (3 << 8)
|
||||
WARNING_AUTOINDEX = WARNING | (1 << 8)
|
||||
AUTH_USER = AUTH | (1 << 8)
|
||||
|
||||
OK_LOAD_PERMANENTLY = OK | (1 << 8)
|
||||
OK_SYMLINK = OK | (2 << 8) /* internal use only */
|
||||
)
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
OffsetErr = ErrorString("sqlite3: invalid offset")
|
||||
TailErr = ErrorString("sqlite3: multiple statements")
|
||||
IsolationErr = ErrorString("sqlite3: unsupported isolation level")
|
||||
NoVFSErr = ErrorString("sqlite3: no such vfs: ")
|
||||
)
|
||||
|
||||
func AssertErr() ErrorString {
|
||||
@@ -40,3 +41,75 @@ func Finalizer[T any](skip int) func(*T) {
|
||||
}
|
||||
return func(*T) { panic(ErrorString(msg)) }
|
||||
}
|
||||
|
||||
func ErrorCodeString(rc uint32) string {
|
||||
switch rc {
|
||||
case ABORT_ROLLBACK:
|
||||
return "sqlite3: abort due to ROLLBACK"
|
||||
case ROW:
|
||||
return "sqlite3: another row available"
|
||||
case DONE:
|
||||
return "sqlite3: no more rows available"
|
||||
}
|
||||
switch rc & 0xff {
|
||||
case OK:
|
||||
return "sqlite3: not an error"
|
||||
case ERROR:
|
||||
return "sqlite3: SQL logic error"
|
||||
case INTERNAL:
|
||||
break
|
||||
case PERM:
|
||||
return "sqlite3: access permission denied"
|
||||
case ABORT:
|
||||
return "sqlite3: query aborted"
|
||||
case BUSY:
|
||||
return "sqlite3: database is locked"
|
||||
case LOCKED:
|
||||
return "sqlite3: database table is locked"
|
||||
case NOMEM:
|
||||
return "sqlite3: out of memory"
|
||||
case READONLY:
|
||||
return "sqlite3: attempt to write a readonly database"
|
||||
case INTERRUPT:
|
||||
return "sqlite3: interrupted"
|
||||
case IOERR:
|
||||
return "sqlite3: disk I/O error"
|
||||
case CORRUPT:
|
||||
return "sqlite3: database disk image is malformed"
|
||||
case NOTFOUND:
|
||||
return "sqlite3: unknown operation"
|
||||
case FULL:
|
||||
return "sqlite3: database or disk is full"
|
||||
case CANTOPEN:
|
||||
return "sqlite3: unable to open database file"
|
||||
case PROTOCOL:
|
||||
return "sqlite3: locking protocol"
|
||||
case FORMAT:
|
||||
break
|
||||
case SCHEMA:
|
||||
return "sqlite3: database schema has changed"
|
||||
case TOOBIG:
|
||||
return "sqlite3: string or blob too big"
|
||||
case CONSTRAINT:
|
||||
return "sqlite3: constraint failed"
|
||||
case MISMATCH:
|
||||
return "sqlite3: datatype mismatch"
|
||||
case MISUSE:
|
||||
return "sqlite3: bad parameter or other API misuse"
|
||||
case NOLFS:
|
||||
break
|
||||
case AUTH:
|
||||
return "sqlite3: authorization denied"
|
||||
case EMPTY:
|
||||
break
|
||||
case RANGE:
|
||||
return "sqlite3: column index out of range"
|
||||
case NOTADB:
|
||||
return "sqlite3: file is not a database"
|
||||
case NOTICE:
|
||||
return "sqlite3: notification message"
|
||||
case WARNING:
|
||||
return "sqlite3: warning message"
|
||||
}
|
||||
return "sqlite3: unknown error"
|
||||
}
|
||||
|
||||
@@ -10,72 +10,119 @@ import (
|
||||
type i32 interface{ ~int32 | ~uint32 }
|
||||
type i64 interface{ ~int64 | ~uint64 }
|
||||
|
||||
func RegisterFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0) TR) {
|
||||
type funcVI[T0 i32] func(context.Context, api.Module, T0)
|
||||
|
||||
func (fn funcVI[T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]))
|
||||
}
|
||||
|
||||
func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
|
||||
}),
|
||||
WithGoModuleFunction(funcVI[T0](fn),
|
||||
[]api.ValueType{api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
|
||||
|
||||
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
|
||||
}
|
||||
|
||||
func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2)) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcVIII[T0, T1, T2](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, nil).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
type funcII[TR, T0 i32] func(context.Context, api.Module, T0) TR
|
||||
|
||||
func (fn funcII[TR, T0]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0])))
|
||||
}
|
||||
|
||||
func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(funcII[TR, T0](fn),
|
||||
[]api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
func RegisterFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) TR) {
|
||||
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
|
||||
|
||||
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
|
||||
}
|
||||
|
||||
func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
|
||||
}),
|
||||
WithGoModuleFunction(funcIII[TR, T0, T1](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
func RegisterFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2) TR) {
|
||||
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
|
||||
|
||||
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
|
||||
}
|
||||
|
||||
func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
|
||||
}),
|
||||
WithGoModuleFunction(funcIIII[TR, T0, T1, T2](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
func RegisterFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) TR) {
|
||||
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
|
||||
|
||||
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
|
||||
}
|
||||
|
||||
func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
|
||||
}),
|
||||
WithGoModuleFunction(funcIIIII[TR, T0, T1, T2, T3](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
func RegisterFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3, _ T4) TR) {
|
||||
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
|
||||
|
||||
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
|
||||
}
|
||||
|
||||
func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3, T4) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
|
||||
}),
|
||||
WithGoModuleFunction(funcIIIIII[TR, T0, T1, T2, T3, T4](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
func RegisterFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) TR) {
|
||||
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
|
||||
|
||||
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
|
||||
}
|
||||
|
||||
func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1, T2, T3) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
|
||||
}),
|
||||
WithGoModuleFunction(funcIIIIJ[TR, T0, T1, T2, T3](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
func RegisterFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) TR) {
|
||||
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
|
||||
|
||||
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
|
||||
}
|
||||
|
||||
func ExportFuncIIJ[TR, T0 i32, T1 i64](mod wazero.HostModuleBuilder, name string, fn func(context.Context, api.Module, T0, T1) TR) {
|
||||
mod.NewFunctionBuilder().
|
||||
WithGoModuleFunction(api.GoModuleFunc(
|
||||
func(ctx context.Context, mod api.Module, stack []uint64) {
|
||||
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
|
||||
}),
|
||||
WithGoModuleFunction(funcIIJ[TR, T0, T1](fn),
|
||||
[]api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}).
|
||||
Export(name)
|
||||
}
|
||||
|
||||
74
internal/util/handle.go
Normal file
74
internal/util/handle.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type handleKey struct{}
|
||||
type handleState struct {
|
||||
handles []any
|
||||
empty int
|
||||
}
|
||||
|
||||
func NewContext(ctx context.Context) (context.Context, io.Closer) {
|
||||
state := new(handleState)
|
||||
return context.WithValue(ctx, handleKey{}, state), state
|
||||
}
|
||||
|
||||
func (s *handleState) Close() (err error) {
|
||||
for _, h := range s.handles {
|
||||
if c, ok := h.(io.Closer); ok {
|
||||
if e := c.Close(); err == nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
}
|
||||
s.handles = nil
|
||||
s.empty = 0
|
||||
return err
|
||||
}
|
||||
|
||||
func GetHandle(ctx context.Context, id uint32) any {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
return s.handles[^id]
|
||||
}
|
||||
|
||||
func DelHandle(ctx context.Context, id uint32) error {
|
||||
if id == 0 {
|
||||
return nil
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
a := s.handles[^id]
|
||||
s.handles[^id] = nil
|
||||
s.empty++
|
||||
if c, ok := a.(io.Closer); ok {
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddHandle(ctx context.Context, a any) (id uint32) {
|
||||
if a == nil {
|
||||
panic(NilErr)
|
||||
}
|
||||
s := ctx.Value(handleKey{}).(*handleState)
|
||||
|
||||
// Find an empty slot.
|
||||
if s.empty > cap(s.handles)-len(s.handles) {
|
||||
for id, h := range s.handles {
|
||||
if h == nil {
|
||||
s.empty--
|
||||
s.handles[id] = a
|
||||
return ^uint32(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new slot.
|
||||
s.handles = append(s.handles, a)
|
||||
return -uint32(len(s.handles))
|
||||
}
|
||||
@@ -3,88 +3,90 @@ package util
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/tetratelabs/wazero/experimental/wazerotest"
|
||||
)
|
||||
|
||||
func TestView_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
View(mock, 0, 8)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestView_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
View(mock, 126, 8)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
View(mock, wazerotest.PageSize-2, 8)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestView_overflow(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
View(mock, 1, math.MaxInt64)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint32_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint32(mock, 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint32_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
ReadUint32(mock, 126)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint32(mock, wazerotest.PageSize-2)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint64_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint64(mock, 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadUint64_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
ReadUint64(mock, 126)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadUint64(mock, wazerotest.PageSize-2)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint32_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint32(mock, 0, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint32_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
WriteUint32(mock, 126, 1)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint32(mock, wazerotest.PageSize-2, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint64_nil(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint64(mock, 0, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestWriteUint64_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
WriteUint64(mock, 126, 1)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
WriteUint64(mock, wazerotest.PageSize-2, 1)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestReadString_range(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
mock := NewMockModule(128)
|
||||
ReadString(mock, 130, math.MaxUint32)
|
||||
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
|
||||
ReadString(mock, wazerotest.PageSize+2, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
@@ -1,153 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func NewMockModule(size uint32) api.Module {
|
||||
mem := make(mockMemory, size)
|
||||
return mockModule{&mem}
|
||||
}
|
||||
|
||||
type mockModule struct {
|
||||
memory api.Memory
|
||||
}
|
||||
|
||||
func (m mockModule) Memory() api.Memory { return m.memory }
|
||||
func (m mockModule) String() string { return "mockModule" }
|
||||
func (m mockModule) Name() string { return "mockModule" }
|
||||
|
||||
func (m mockModule) ExportedGlobal(name string) api.Global { return nil }
|
||||
func (m mockModule) ExportedMemory(name string) api.Memory { return nil }
|
||||
func (m mockModule) ExportedFunction(name string) api.Function { return nil }
|
||||
func (m mockModule) ExportedMemoryDefinitions() map[string]api.MemoryDefinition { return nil }
|
||||
func (m mockModule) ExportedFunctionDefinitions() map[string]api.FunctionDefinition { return nil }
|
||||
func (m mockModule) CloseWithExitCode(ctx context.Context, exitCode uint32) error { return nil }
|
||||
func (m mockModule) Close(context.Context) error { return nil }
|
||||
|
||||
type mockMemory []byte
|
||||
|
||||
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
|
||||
|
||||
func (m mockMemory) Size() uint32 { return uint32(len(m)) }
|
||||
|
||||
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
|
||||
if offset >= m.Size() {
|
||||
return 0, false
|
||||
}
|
||||
return m[offset], true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint16(m[offset : offset+2]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint32(m[offset : offset+4]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
|
||||
v, ok := m.ReadUint32Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float32frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return 0, false
|
||||
}
|
||||
return binary.LittleEndian.Uint64(m[offset : offset+8]), true
|
||||
}
|
||||
|
||||
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
|
||||
v, ok := m.ReadUint64Le(offset)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float64frombits(v), true
|
||||
}
|
||||
|
||||
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
|
||||
if !m.hasSize(offset, byteCount) {
|
||||
return nil, false
|
||||
}
|
||||
return m[offset : offset+byteCount : offset+byteCount], true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
|
||||
if offset >= m.Size() {
|
||||
return false
|
||||
}
|
||||
m[offset] = v
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
|
||||
if !m.hasSize(offset, 2) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint16(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
|
||||
if !m.hasSize(offset, 4) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint32(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
|
||||
return m.WriteUint32Le(offset, math.Float32bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
|
||||
if !m.hasSize(offset, 8) {
|
||||
return false
|
||||
}
|
||||
binary.LittleEndian.PutUint64(m[offset:], v)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
|
||||
return m.WriteUint64Le(offset, math.Float64bits(v))
|
||||
}
|
||||
|
||||
func (m mockMemory) Write(offset uint32, val []byte) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m mockMemory) WriteString(offset uint32, val string) bool {
|
||||
if !m.hasSize(offset, uint32(len(val))) {
|
||||
return false
|
||||
}
|
||||
copy(m[offset:], val)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
|
||||
prev := (len(*m) + 65535) / 65536
|
||||
*m = append(*m, make([]byte, 65536*delta)...)
|
||||
return uint32(prev), true
|
||||
}
|
||||
|
||||
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
|
||||
return uint64(offset)+uint64(byteCount) <= uint64(len(m))
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_mockMemory_byte(t *testing.T) {
|
||||
const want byte = 98
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadByte(128)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteByte(128, 0)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteByte(0, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().ReadByte(0)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_uint16(t *testing.T) {
|
||||
const want uint16 = 9876
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadUint16Le(128)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteUint16Le(128, 0)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteUint16Le(0, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().ReadUint16Le(0)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_uint32(t *testing.T) {
|
||||
const want uint32 = 987654321
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadUint32Le(128)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteUint32Le(128, 0)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteUint32Le(0, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().ReadUint32Le(0)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_uint64(t *testing.T) {
|
||||
const want uint64 = 9876543210
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadUint64Le(128)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteUint64Le(128, 0)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteUint64Le(0, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().ReadUint64Le(0)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_float32(t *testing.T) {
|
||||
const want float32 = math.Pi
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadFloat32Le(128)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteFloat32Le(128, 0)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteFloat32Le(0, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().ReadFloat32Le(0)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %f, want %f", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_float64(t *testing.T) {
|
||||
const want float64 = math.Pi
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadFloat64Le(128)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteFloat64Le(128, 0)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteFloat64Le(0, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().ReadFloat64Le(0)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got %f, want %f", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_bytes(t *testing.T) {
|
||||
const want string = "\xca\xfe\xba\xbe"
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().Read(128, uint32(len(want)))
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().Write(128, []byte(want))
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteString(128, want)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
ok = mock.Memory().Write(0, []byte(want))
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().Read(0, uint32(len(want)))
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
ok = mock.Memory().WriteString(64, want)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
|
||||
got, ok = mock.Memory().Read(64, uint32(len(want)))
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mockMemory_grow(t *testing.T) {
|
||||
mock := NewMockModule(128)
|
||||
|
||||
_, ok := mock.Memory().ReadByte(65536)
|
||||
if ok {
|
||||
t.Error("want error")
|
||||
}
|
||||
|
||||
got, ok := mock.Memory().Grow(1)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
if got != 1 {
|
||||
t.Errorf("got %d, want 1", got)
|
||||
}
|
||||
|
||||
_, ok = mock.Memory().ReadByte(65536)
|
||||
if !ok {
|
||||
t.Error("want ok")
|
||||
}
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
package vfs
|
||||
|
||||
const (
|
||||
_MAX_PATHNAME = 512
|
||||
_DEFAULT_SECTOR_SIZE = 4096
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/rescode.html
|
||||
type _ErrorCode uint32
|
||||
|
||||
const (
|
||||
_OK _ErrorCode = 0 /* Successful result */
|
||||
_PERM _ErrorCode = 3 /* Access permission denied */
|
||||
_BUSY _ErrorCode = 5 /* The database file is locked */
|
||||
_IOERR _ErrorCode = 10 /* Some kind of disk I/O error occurred */
|
||||
_NOTFOUND _ErrorCode = 12 /* Unknown opcode in sqlite3_file_control() */
|
||||
_CANTOPEN _ErrorCode = 14 /* Unable to open the database file */
|
||||
|
||||
_IOERR_READ = _IOERR | (1 << 8)
|
||||
_IOERR_SHORT_READ = _IOERR | (2 << 8)
|
||||
_IOERR_WRITE = _IOERR | (3 << 8)
|
||||
_IOERR_FSYNC = _IOERR | (4 << 8)
|
||||
_IOERR_DIR_FSYNC = _IOERR | (5 << 8)
|
||||
_IOERR_TRUNCATE = _IOERR | (6 << 8)
|
||||
_IOERR_FSTAT = _IOERR | (7 << 8)
|
||||
_IOERR_UNLOCK = _IOERR | (8 << 8)
|
||||
_IOERR_RDLOCK = _IOERR | (9 << 8)
|
||||
_IOERR_DELETE = _IOERR | (10 << 8)
|
||||
_IOERR_BLOCKED = _IOERR | (11 << 8)
|
||||
_IOERR_NOMEM = _IOERR | (12 << 8)
|
||||
_IOERR_ACCESS = _IOERR | (13 << 8)
|
||||
_IOERR_CHECKRESERVEDLOCK = _IOERR | (14 << 8)
|
||||
_IOERR_LOCK = _IOERR | (15 << 8)
|
||||
_IOERR_CLOSE = _IOERR | (16 << 8)
|
||||
_IOERR_DIR_CLOSE = _IOERR | (17 << 8)
|
||||
_IOERR_SHMOPEN = _IOERR | (18 << 8)
|
||||
_IOERR_SHMSIZE = _IOERR | (19 << 8)
|
||||
_IOERR_SHMLOCK = _IOERR | (20 << 8)
|
||||
_IOERR_SHMMAP = _IOERR | (21 << 8)
|
||||
_IOERR_SEEK = _IOERR | (22 << 8)
|
||||
_IOERR_DELETE_NOENT = _IOERR | (23 << 8)
|
||||
_IOERR_MMAP = _IOERR | (24 << 8)
|
||||
_IOERR_GETTEMPPATH = _IOERR | (25 << 8)
|
||||
_IOERR_CONVPATH = _IOERR | (26 << 8)
|
||||
_IOERR_VNODE = _IOERR | (27 << 8)
|
||||
_IOERR_AUTH = _IOERR | (28 << 8)
|
||||
_IOERR_BEGIN_ATOMIC = _IOERR | (29 << 8)
|
||||
_IOERR_COMMIT_ATOMIC = _IOERR | (30 << 8)
|
||||
_IOERR_ROLLBACK_ATOMIC = _IOERR | (31 << 8)
|
||||
_IOERR_DATA = _IOERR | (32 << 8)
|
||||
_IOERR_CORRUPTFS = _IOERR | (33 << 8)
|
||||
_CANTOPEN_NOTEMPDIR = _CANTOPEN | (1 << 8)
|
||||
_CANTOPEN_ISDIR = _CANTOPEN | (2 << 8)
|
||||
_CANTOPEN_FULLPATH = _CANTOPEN | (3 << 8)
|
||||
_CANTOPEN_CONVPATH = _CANTOPEN | (4 << 8)
|
||||
_CANTOPEN_DIRTYWAL = _CANTOPEN | (5 << 8) /* Not Used */
|
||||
_CANTOPEN_SYMLINK = _CANTOPEN | (6 << 8)
|
||||
_OK_SYMLINK = _OK | (2 << 8) /* internal use only */
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
|
||||
type _OpenFlag uint32
|
||||
|
||||
const (
|
||||
_OPEN_READONLY _OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_READWRITE _OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_CREATE _OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_DELETEONCLOSE _OpenFlag = 0x00000008 /* VFS only */
|
||||
_OPEN_EXCLUSIVE _OpenFlag = 0x00000010 /* VFS only */
|
||||
_OPEN_AUTOPROXY _OpenFlag = 0x00000020 /* VFS only */
|
||||
_OPEN_URI _OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_MEMORY _OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_MAIN_DB _OpenFlag = 0x00000100 /* VFS only */
|
||||
_OPEN_TEMP_DB _OpenFlag = 0x00000200 /* VFS only */
|
||||
_OPEN_TRANSIENT_DB _OpenFlag = 0x00000400 /* VFS only */
|
||||
_OPEN_MAIN_JOURNAL _OpenFlag = 0x00000800 /* VFS only */
|
||||
_OPEN_TEMP_JOURNAL _OpenFlag = 0x00001000 /* VFS only */
|
||||
_OPEN_SUBJOURNAL _OpenFlag = 0x00002000 /* VFS only */
|
||||
_OPEN_SUPER_JOURNAL _OpenFlag = 0x00004000 /* VFS only */
|
||||
_OPEN_NOMUTEX _OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_FULLMUTEX _OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_SHAREDCACHE _OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_PRIVATECACHE _OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_WAL _OpenFlag = 0x00080000 /* VFS only */
|
||||
_OPEN_NOFOLLOW _OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
|
||||
_OPEN_EXRESCODE _OpenFlag = 0x02000000 /* Extended result codes */
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_access_exists.html
|
||||
type _AccessFlag uint32
|
||||
|
||||
const (
|
||||
_ACCESS_EXISTS _AccessFlag = 0
|
||||
_ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
|
||||
_ACCESS_READ _AccessFlag = 2 /* Unused */
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_sync_dataonly.html
|
||||
type _SyncFlag uint32
|
||||
|
||||
const (
|
||||
_SYNC_NORMAL _SyncFlag = 0x00002
|
||||
_SYNC_FULL _SyncFlag = 0x00003
|
||||
_SYNC_DATAONLY _SyncFlag = 0x00010
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_lock_exclusive.html
|
||||
type _LockLevel uint32
|
||||
|
||||
const (
|
||||
// No locks are held on the database.
|
||||
// The database may be neither read nor written.
|
||||
// Any internally cached data is considered suspect and subject to
|
||||
// verification against the database file before being used.
|
||||
// Other processes can read or write the database as their own locking
|
||||
// states permit.
|
||||
// This is the default state.
|
||||
_LOCK_NONE _LockLevel = 0 /* xUnlock() only */
|
||||
|
||||
// The database may be read but not written.
|
||||
// Any number of processes can hold SHARED locks at the same time,
|
||||
// hence there can be many simultaneous readers.
|
||||
// But no other thread or process is allowed to write to the database file
|
||||
// while one or more SHARED locks are active.
|
||||
_LOCK_SHARED _LockLevel = 1 /* xLock() or xUnlock() */
|
||||
|
||||
// A RESERVED lock means that the process is planning on writing to the
|
||||
// database file at some point in the future but that it is currently just
|
||||
// reading from the file.
|
||||
// Only a single RESERVED lock may be active at one time,
|
||||
// though multiple SHARED locks can coexist with a single RESERVED lock.
|
||||
// RESERVED differs from PENDING in that new SHARED locks can be acquired
|
||||
// while there is a RESERVED lock.
|
||||
_LOCK_RESERVED _LockLevel = 2 /* xLock() only */
|
||||
|
||||
// A PENDING lock means that the process holding the lock wants to write to
|
||||
// the database as soon as possible and is just waiting on all current
|
||||
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
|
||||
// No new SHARED locks are permitted against the database if a PENDING lock
|
||||
// is active, though existing SHARED locks are allowed to continue.
|
||||
_LOCK_PENDING _LockLevel = 3 /* internal use only */
|
||||
|
||||
// An EXCLUSIVE lock is needed in order to write to the database file.
|
||||
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
|
||||
// kind are allowed to coexist with an EXCLUSIVE lock.
|
||||
// In order to maximize concurrency, SQLite works to minimize the amount of
|
||||
// time that EXCLUSIVE locks are held.
|
||||
_LOCK_EXCLUSIVE _LockLevel = 4 /* xLock() only */
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type _FcntlOpcode uint32
|
||||
|
||||
const (
|
||||
_FCNTL_LOCKSTATE _FcntlOpcode = 1
|
||||
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
|
||||
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
|
||||
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
|
||||
_FCNTL_SIZE_HINT _FcntlOpcode = 5
|
||||
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
|
||||
_FCNTL_FILE_POINTER _FcntlOpcode = 7
|
||||
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
|
||||
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
|
||||
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
|
||||
_FCNTL_OVERWRITE _FcntlOpcode = 11
|
||||
_FCNTL_VFSNAME _FcntlOpcode = 12
|
||||
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
|
||||
_FCNTL_PRAGMA _FcntlOpcode = 14
|
||||
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
|
||||
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
|
||||
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
|
||||
_FCNTL_TRACE _FcntlOpcode = 19
|
||||
_FCNTL_HAS_MOVED _FcntlOpcode = 20
|
||||
_FCNTL_SYNC _FcntlOpcode = 21
|
||||
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
|
||||
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
|
||||
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
|
||||
_FCNTL_ZIPVFS _FcntlOpcode = 25
|
||||
_FCNTL_RBU _FcntlOpcode = 26
|
||||
_FCNTL_VFS_POINTER _FcntlOpcode = 27
|
||||
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
|
||||
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
|
||||
_FCNTL_PDB _FcntlOpcode = 30
|
||||
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
|
||||
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
|
||||
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
|
||||
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
|
||||
_FCNTL_DATA_VERSION _FcntlOpcode = 35
|
||||
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
|
||||
_FCNTL_CKPT_DONE _FcntlOpcode = 37
|
||||
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
|
||||
_FCNTL_CKPT_START _FcntlOpcode = 39
|
||||
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
|
||||
_FCNTL_CKSM_FILE _FcntlOpcode = 41
|
||||
_FCNTL_RESET_CACHE _FcntlOpcode = 42
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_iocap_atomic.html
|
||||
type _DeviceCharacteristic uint32
|
||||
|
||||
const (
|
||||
_IOCAP_ATOMIC _DeviceCharacteristic = 0x00000001
|
||||
_IOCAP_ATOMIC512 _DeviceCharacteristic = 0x00000002
|
||||
_IOCAP_ATOMIC1K _DeviceCharacteristic = 0x00000004
|
||||
_IOCAP_ATOMIC2K _DeviceCharacteristic = 0x00000008
|
||||
_IOCAP_ATOMIC4K _DeviceCharacteristic = 0x00000010
|
||||
_IOCAP_ATOMIC8K _DeviceCharacteristic = 0x00000020
|
||||
_IOCAP_ATOMIC16K _DeviceCharacteristic = 0x00000040
|
||||
_IOCAP_ATOMIC32K _DeviceCharacteristic = 0x00000080
|
||||
_IOCAP_ATOMIC64K _DeviceCharacteristic = 0x00000100
|
||||
_IOCAP_SAFE_APPEND _DeviceCharacteristic = 0x00000200
|
||||
_IOCAP_SEQUENTIAL _DeviceCharacteristic = 0x00000400
|
||||
_IOCAP_UNDELETABLE_WHEN_OPEN _DeviceCharacteristic = 0x00000800
|
||||
_IOCAP_POWERSAFE_OVERWRITE _DeviceCharacteristic = 0x00001000
|
||||
_IOCAP_IMMUTABLE _DeviceCharacteristic = 0x00002000
|
||||
_IOCAP_BATCH_ATOMIC _DeviceCharacteristic = 0x00004000
|
||||
)
|
||||
22
internal/vfs/tests/mptest/testdata/main.c
vendored
22
internal/vfs/tests/mptest/testdata/main.c
vendored
@@ -1,22 +0,0 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.c"
|
||||
//
|
||||
#include "os.c"
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
|
||||
|
||||
int sqlite3_os_init() {
|
||||
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
|
||||
}
|
||||
|
||||
__attribute__((constructor)) void premain() { sqlite3_initialize(); }
|
||||
|
||||
static int dont_unlink(const char *pathname) { return 0; }
|
||||
#define sqlite3_enable_load_extension(...)
|
||||
#define sqlite3_trace(...)
|
||||
#define unlink dont_unlink
|
||||
#undef UNUSED_PARAMETER
|
||||
#include "mptest.c"
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:37ceeed293b9f09e9770b40eda3f625447f5a3a74208709886d4411d12f93414
|
||||
size 1486113
|
||||
16
internal/vfs/tests/speedtest1/testdata/main.c
vendored
16
internal/vfs/tests/speedtest1/testdata/main.c
vendored
@@ -1,16 +0,0 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "sqlite3.c"
|
||||
//
|
||||
#include "os.c"
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
|
||||
|
||||
int sqlite3_os_init() {
|
||||
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
|
||||
}
|
||||
|
||||
#define randomFunc(args...) randomFunc2(args)
|
||||
#include "speedtest1.c"
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8167119c344a68217b0301e2e8c288f2e75611d296d7822f841b65911da0275c
|
||||
size 1520569
|
||||
@@ -1,390 +0,0 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/julianday"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func Export(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
|
||||
util.RegisterFuncIIJ(env, "os_localtime", vfsLocaltime)
|
||||
util.RegisterFuncIIII(env, "os_randomness", vfsRandomness)
|
||||
util.RegisterFuncIII(env, "os_sleep", vfsSleep)
|
||||
util.RegisterFuncIII(env, "os_current_time", vfsCurrentTime)
|
||||
util.RegisterFuncIII(env, "os_current_time_64", vfsCurrentTime64)
|
||||
util.RegisterFuncIIIII(env, "os_full_pathname", vfsFullPathname)
|
||||
util.RegisterFuncIIII(env, "os_delete", vfsDelete)
|
||||
util.RegisterFuncIIIII(env, "os_access", vfsAccess)
|
||||
util.RegisterFuncIIIIII(env, "os_open", vfsOpen)
|
||||
util.RegisterFuncII(env, "os_close", vfsClose)
|
||||
util.RegisterFuncIIIIJ(env, "os_read", vfsRead)
|
||||
util.RegisterFuncIIIIJ(env, "os_write", vfsWrite)
|
||||
util.RegisterFuncIIJ(env, "os_truncate", vfsTruncate)
|
||||
util.RegisterFuncIII(env, "os_sync", vfsSync)
|
||||
util.RegisterFuncIII(env, "os_file_size", vfsFileSize)
|
||||
util.RegisterFuncIIII(env, "os_file_control", vfsFileControl)
|
||||
util.RegisterFuncII(env, "os_sector_size", vfsSectorSize)
|
||||
util.RegisterFuncII(env, "os_device_characteristics", vfsDeviceCharacteristics)
|
||||
util.RegisterFuncIII(env, "os_lock", vfsLock)
|
||||
util.RegisterFuncIII(env, "os_unlock", vfsUnlock)
|
||||
util.RegisterFuncIII(env, "os_check_reserved_lock", vfsCheckReservedLock)
|
||||
return env
|
||||
}
|
||||
|
||||
type vfsKey struct{}
|
||||
type vfsState struct {
|
||||
files []vfsFile
|
||||
}
|
||||
|
||||
func Context(ctx context.Context) (context.Context, io.Closer) {
|
||||
vfs := &vfsState{}
|
||||
return context.WithValue(ctx, vfsKey{}, vfs), vfs
|
||||
}
|
||||
|
||||
func (vfs *vfsState) Close() error {
|
||||
for _, f := range vfs.files {
|
||||
if f.File != nil {
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
vfs.files = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode {
|
||||
tm := time.Unix(t, 0)
|
||||
var isdst int
|
||||
if tm.IsDST() {
|
||||
isdst = 1
|
||||
}
|
||||
|
||||
const size = 32 / 8
|
||||
// https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html
|
||||
util.WriteUint32(mod, pTm+0*size, uint32(tm.Second()))
|
||||
util.WriteUint32(mod, pTm+1*size, uint32(tm.Minute()))
|
||||
util.WriteUint32(mod, pTm+2*size, uint32(tm.Hour()))
|
||||
util.WriteUint32(mod, pTm+3*size, uint32(tm.Day()))
|
||||
util.WriteUint32(mod, pTm+4*size, uint32(tm.Month()-time.January))
|
||||
util.WriteUint32(mod, pTm+5*size, uint32(tm.Year()-1900))
|
||||
util.WriteUint32(mod, pTm+6*size, uint32(tm.Weekday()-time.Sunday))
|
||||
util.WriteUint32(mod, pTm+7*size, uint32(tm.YearDay()-1))
|
||||
util.WriteUint32(mod, pTm+8*size, uint32(isdst))
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 {
|
||||
mem := util.View(mod, zByte, uint64(nByte))
|
||||
n, _ := rand.Reader.Read(mem)
|
||||
return uint32(n)
|
||||
}
|
||||
|
||||
func vfsSleep(ctx context.Context, mod api.Module, pVfs, nMicro uint32) _ErrorCode {
|
||||
time.Sleep(time.Duration(nMicro) * time.Microsecond)
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) _ErrorCode {
|
||||
day := julianday.Float(time.Now())
|
||||
util.WriteFloat64(mod, prNow, day)
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode {
|
||||
day, nsec := julianday.Date(time.Now())
|
||||
msec := day*86_400_000 + nsec/1_000_000
|
||||
util.WriteUint64(mod, piNow, uint64(msec))
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) _ErrorCode {
|
||||
rel := util.ReadString(mod, zRelative, _MAX_PATHNAME)
|
||||
abs, err := filepath.Abs(rel)
|
||||
if err != nil {
|
||||
return _CANTOPEN_FULLPATH
|
||||
}
|
||||
|
||||
size := uint64(len(abs) + 1)
|
||||
if size > uint64(nFull) {
|
||||
return _CANTOPEN_FULLPATH
|
||||
}
|
||||
mem := util.View(mod, zFull, size)
|
||||
mem[len(abs)] = 0
|
||||
copy(mem, abs)
|
||||
|
||||
if fi, err := os.Lstat(abs); err == nil {
|
||||
if fi.Mode()&fs.ModeSymlink != 0 {
|
||||
return _OK_SYMLINK
|
||||
}
|
||||
return _OK
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
return _OK
|
||||
}
|
||||
return _CANTOPEN_FULLPATH
|
||||
}
|
||||
|
||||
func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode {
|
||||
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
|
||||
err := os.Remove(path)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return _IOERR_DELETE_NOENT
|
||||
}
|
||||
if err != nil {
|
||||
return _IOERR_DELETE
|
||||
}
|
||||
if runtime.GOOS != "windows" && syncDir != 0 {
|
||||
f, err := os.Open(filepath.Dir(path))
|
||||
if err != nil {
|
||||
return _OK
|
||||
}
|
||||
defer f.Close()
|
||||
err = osSync(f, false, false)
|
||||
if err != nil {
|
||||
return _IOERR_DIR_FSYNC
|
||||
}
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) _ErrorCode {
|
||||
path := util.ReadString(mod, zPath, _MAX_PATHNAME)
|
||||
err := osAccess(path, flags)
|
||||
|
||||
var res uint32
|
||||
var rc _ErrorCode
|
||||
if flags == _ACCESS_EXISTS {
|
||||
switch {
|
||||
case err == nil:
|
||||
res = 1
|
||||
case errors.Is(err, fs.ErrNotExist):
|
||||
res = 0
|
||||
default:
|
||||
rc = _IOERR_ACCESS
|
||||
}
|
||||
} else {
|
||||
switch {
|
||||
case err == nil:
|
||||
res = 1
|
||||
case errors.Is(err, fs.ErrPermission):
|
||||
res = 0
|
||||
default:
|
||||
rc = _IOERR_ACCESS
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteUint32(mod, pResOut, res)
|
||||
return rc
|
||||
}
|
||||
|
||||
func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags _OpenFlag, pOutFlags uint32) _ErrorCode {
|
||||
var oflags int
|
||||
if flags&_OPEN_EXCLUSIVE != 0 {
|
||||
oflags |= os.O_EXCL
|
||||
}
|
||||
if flags&_OPEN_CREATE != 0 {
|
||||
oflags |= os.O_CREATE
|
||||
}
|
||||
if flags&_OPEN_READONLY != 0 {
|
||||
oflags |= os.O_RDONLY
|
||||
}
|
||||
if flags&_OPEN_READWRITE != 0 {
|
||||
oflags |= os.O_RDWR
|
||||
}
|
||||
|
||||
var err error
|
||||
var f *os.File
|
||||
if zName == 0 {
|
||||
f, err = os.CreateTemp("", "*.db")
|
||||
} else {
|
||||
name := util.ReadString(mod, zName, _MAX_PATHNAME)
|
||||
f, err = osOpenFile(name, oflags, 0666)
|
||||
}
|
||||
if err != nil {
|
||||
return _CANTOPEN
|
||||
}
|
||||
|
||||
if flags&_OPEN_DELETEONCLOSE != 0 {
|
||||
os.Remove(f.Name())
|
||||
}
|
||||
|
||||
file := openVFSFile(ctx, mod, pFile, f)
|
||||
file.psow = true
|
||||
file.readOnly = flags&_OPEN_READONLY != 0
|
||||
file.syncDir = runtime.GOOS != "windows" &&
|
||||
flags&(_OPEN_CREATE) != 0 &&
|
||||
flags&(_OPEN_MAIN_JOURNAL|_OPEN_SUPER_JOURNAL|_OPEN_WAL) != 0
|
||||
|
||||
if pOutFlags != 0 {
|
||||
util.WriteUint32(mod, pOutFlags, uint32(flags))
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode {
|
||||
err := closeVFSFile(ctx, mod, pFile)
|
||||
if err != nil {
|
||||
return _IOERR_CLOSE
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
|
||||
buf := util.View(mod, zBuf, uint64(iAmt))
|
||||
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
n, err := file.ReadAt(buf, iOfst)
|
||||
if n == int(iAmt) {
|
||||
return _OK
|
||||
}
|
||||
if n == 0 && err != io.EOF {
|
||||
return _IOERR_READ
|
||||
}
|
||||
for i := range buf[n:] {
|
||||
buf[n+i] = 0
|
||||
}
|
||||
return _IOERR_SHORT_READ
|
||||
}
|
||||
|
||||
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode {
|
||||
buf := util.View(mod, zBuf, uint64(iAmt))
|
||||
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
_, err := file.WriteAt(buf, iOfst)
|
||||
if err != nil {
|
||||
return _IOERR_WRITE
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode {
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
err := file.Truncate(nByte)
|
||||
if err != nil {
|
||||
return _IOERR_TRUNCATE
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) _ErrorCode {
|
||||
dataonly := (flags & _SYNC_DATAONLY) != 0
|
||||
fullsync := (flags & 0x0f) == _SYNC_FULL
|
||||
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
err := osSync(file.File, fullsync, dataonly)
|
||||
if err != nil {
|
||||
return _IOERR_FSYNC
|
||||
}
|
||||
if runtime.GOOS != "windows" && file.syncDir {
|
||||
file.syncDir = false
|
||||
f, err := os.Open(filepath.Dir(file.Name()))
|
||||
if err != nil {
|
||||
return _OK
|
||||
}
|
||||
defer f.Close()
|
||||
err = osSync(f, false, false)
|
||||
if err != nil {
|
||||
return _IOERR_DIR_FSYNC
|
||||
}
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode {
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
off, err := file.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return _IOERR_SEEK
|
||||
}
|
||||
|
||||
util.WriteUint64(mod, pSize, uint64(off))
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode {
|
||||
switch op {
|
||||
case _FCNTL_LOCKSTATE:
|
||||
util.WriteUint32(mod, pArg, uint32(getVFSFile(ctx, mod, pFile).lock))
|
||||
return _OK
|
||||
case _FCNTL_LOCK_TIMEOUT:
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
millis := file.lockTimeout.Milliseconds()
|
||||
file.lockTimeout = time.Duration(util.ReadUint32(mod, pArg)) * time.Millisecond
|
||||
util.WriteUint32(mod, pArg, uint32(millis))
|
||||
return _OK
|
||||
case _FCNTL_POWERSAFE_OVERWRITE:
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
switch util.ReadUint32(mod, pArg) {
|
||||
case 0:
|
||||
file.psow = false
|
||||
case 1:
|
||||
file.psow = true
|
||||
default:
|
||||
if file.psow {
|
||||
util.WriteUint32(mod, pArg, 1)
|
||||
} else {
|
||||
util.WriteUint32(mod, pArg, 0)
|
||||
}
|
||||
}
|
||||
case _FCNTL_SIZE_HINT:
|
||||
return vfsSizeHint(ctx, mod, pFile, pArg)
|
||||
case _FCNTL_HAS_MOVED:
|
||||
return vfsFileMoved(ctx, mod, pFile, pArg)
|
||||
}
|
||||
// Consider also implementing these opcodes (in use by SQLite):
|
||||
// _FCNTL_BUSYHANDLER
|
||||
// _FCNTL_COMMIT_PHASETWO
|
||||
// _FCNTL_PDB
|
||||
// _FCNTL_PRAGMA
|
||||
// _FCNTL_SYNC
|
||||
return _NOTFOUND
|
||||
}
|
||||
|
||||
func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 {
|
||||
return _DEFAULT_SECTOR_SIZE
|
||||
}
|
||||
|
||||
func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) _DeviceCharacteristic {
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
if file.psow {
|
||||
return _IOCAP_POWERSAFE_OVERWRITE
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func vfsSizeHint(ctx context.Context, mod api.Module, pFile, pArg uint32) _ErrorCode {
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
size := util.ReadUint64(mod, pArg)
|
||||
err := osAllocate(file.File, int64(size))
|
||||
if err != nil {
|
||||
return _IOERR_TRUNCATE
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func vfsFileMoved(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
fi, err := file.Stat()
|
||||
if err != nil {
|
||||
return _IOERR_FSTAT
|
||||
}
|
||||
pi, err := os.Stat(file.Name())
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return _IOERR_FSTAT
|
||||
}
|
||||
var res uint32
|
||||
if !os.SameFile(fi, pi) {
|
||||
res = 1
|
||||
}
|
||||
util.WriteUint32(mod, pResOut, res)
|
||||
return _OK
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
type vfsFile struct {
|
||||
*os.File
|
||||
lockTimeout time.Duration
|
||||
lock _LockLevel
|
||||
psow bool
|
||||
syncDir bool
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
func newVFSFile(vfs *vfsState, file *os.File) uint32 {
|
||||
// Find an empty slot.
|
||||
for id, f := range vfs.files {
|
||||
if f.File == nil {
|
||||
vfs.files[id] = vfsFile{File: file}
|
||||
return uint32(id)
|
||||
}
|
||||
}
|
||||
|
||||
// Add a new slot.
|
||||
vfs.files = append(vfs.files, vfsFile{File: file})
|
||||
return uint32(len(vfs.files) - 1)
|
||||
}
|
||||
|
||||
func getVFSFile(ctx context.Context, mod api.Module, pFile uint32) *vfsFile {
|
||||
vfs := ctx.Value(vfsKey{}).(*vfsState)
|
||||
id := util.ReadUint32(mod, pFile+4)
|
||||
return &vfs.files[id]
|
||||
}
|
||||
|
||||
func openVFSFile(ctx context.Context, mod api.Module, pFile uint32, file *os.File) *vfsFile {
|
||||
vfs := ctx.Value(vfsKey{}).(*vfsState)
|
||||
id := newVFSFile(vfs, file)
|
||||
util.WriteUint32(mod, pFile+4, id)
|
||||
return &vfs.files[id]
|
||||
}
|
||||
|
||||
func closeVFSFile(ctx context.Context, mod api.Module, pFile uint32) error {
|
||||
vfs := ctx.Value(vfsKey{}).(*vfsState)
|
||||
id := util.ReadUint32(mod, pFile+4)
|
||||
file := vfs.files[id]
|
||||
vfs.files[id] = vfsFile{}
|
||||
return file.Close()
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
const (
|
||||
_PENDING_BYTE = 0x40000000
|
||||
_RESERVED_BYTE = (_PENDING_BYTE + 1)
|
||||
_SHARED_FIRST = (_PENDING_BYTE + 2)
|
||||
_SHARED_SIZE = 510
|
||||
)
|
||||
|
||||
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock _LockLevel) _ErrorCode {
|
||||
// Argument check. SQLite never explicitly requests a pending lock.
|
||||
if eLock != _LOCK_SHARED && eLock != _LOCK_RESERVED && eLock != _LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
|
||||
switch {
|
||||
case file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE:
|
||||
// Connection state check.
|
||||
panic(util.AssertErr())
|
||||
case file.lock == _LOCK_NONE && eLock > _LOCK_SHARED:
|
||||
// We never move from unlocked to anything higher than a shared lock.
|
||||
panic(util.AssertErr())
|
||||
case file.lock != _LOCK_SHARED && eLock == _LOCK_RESERVED:
|
||||
// A shared lock is always held when a reserved lock is requested.
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
// If we already have an equal or more restrictive lock, do nothing.
|
||||
if file.lock >= eLock {
|
||||
return _OK
|
||||
}
|
||||
|
||||
// Do not allow any kind of write-lock on a read-only database.
|
||||
if file.readOnly && eLock >= _LOCK_RESERVED {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
|
||||
switch eLock {
|
||||
case _LOCK_SHARED:
|
||||
// Must be unlocked to get SHARED.
|
||||
if file.lock != _LOCK_NONE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetSharedLock(file.File, file.lockTimeout); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
file.lock = _LOCK_SHARED
|
||||
return _OK
|
||||
|
||||
case _LOCK_RESERVED:
|
||||
// Must be SHARED to get RESERVED.
|
||||
if file.lock != _LOCK_SHARED {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetReservedLock(file.File, file.lockTimeout); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
file.lock = _LOCK_RESERVED
|
||||
return _OK
|
||||
|
||||
case _LOCK_EXCLUSIVE:
|
||||
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
|
||||
if file.lock <= _LOCK_NONE || file.lock >= _LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
|
||||
if file.lock < _LOCK_PENDING {
|
||||
if rc := osGetPendingLock(file.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
file.lock = _LOCK_PENDING
|
||||
}
|
||||
if rc := osGetExclusiveLock(file.File, file.lockTimeout); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
file.lock = _LOCK_EXCLUSIVE
|
||||
return _OK
|
||||
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock _LockLevel) _ErrorCode {
|
||||
// Argument check.
|
||||
if eLock != _LOCK_NONE && eLock != _LOCK_SHARED {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
|
||||
// Connection state check.
|
||||
if file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
// If we don't have a more restrictive lock, do nothing.
|
||||
if file.lock <= eLock {
|
||||
return _OK
|
||||
}
|
||||
|
||||
switch eLock {
|
||||
case _LOCK_SHARED:
|
||||
if rc := osDowngradeLock(file.File, file.lock); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
file.lock = _LOCK_SHARED
|
||||
return _OK
|
||||
|
||||
case _LOCK_NONE:
|
||||
rc := osReleaseLock(file.File, file.lock)
|
||||
file.lock = _LOCK_NONE
|
||||
return rc
|
||||
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
|
||||
file := getVFSFile(ctx, mod, pFile)
|
||||
|
||||
// Connection state check.
|
||||
if file.lock < _LOCK_NONE || file.lock > _LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
var locked bool
|
||||
var rc _ErrorCode
|
||||
if file.lock >= _LOCK_RESERVED {
|
||||
locked = true
|
||||
} else {
|
||||
locked, rc = osCheckReservedLock(file.File)
|
||||
}
|
||||
|
||||
var res uint32
|
||||
if locked {
|
||||
res = 1
|
||||
}
|
||||
util.WriteUint32(mod, pResOut, res)
|
||||
return rc
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock.
|
||||
return osWriteLock(file, _PENDING_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
// Test the RESERVED lock.
|
||||
return osCheckLock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/go-sqlite3/internal/vfs"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
@@ -25,67 +25,72 @@ var (
|
||||
Path string // Path to load the binary from.
|
||||
)
|
||||
|
||||
var sqlite3 struct {
|
||||
var instance struct {
|
||||
runtime wazero.Runtime
|
||||
compiled wazero.CompiledModule
|
||||
err error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func instantiateModule() (*module, error) {
|
||||
func instantiateSQLite() (*sqlite, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
sqlite3.once.Do(compileModule)
|
||||
if sqlite3.err != nil {
|
||||
return nil, sqlite3.err
|
||||
instance.once.Do(compileSQLite)
|
||||
if instance.err != nil {
|
||||
return nil, instance.err
|
||||
}
|
||||
|
||||
cfg := wazero.NewModuleConfig()
|
||||
|
||||
mod, err := sqlite3.runtime.InstantiateModule(ctx, sqlite3.compiled, cfg)
|
||||
mod, err := instance.runtime.InstantiateModule(ctx, instance.compiled, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newModule(mod)
|
||||
return newSQLite(mod)
|
||||
}
|
||||
|
||||
func compileModule() {
|
||||
func compileSQLite() {
|
||||
ctx := context.Background()
|
||||
sqlite3.runtime = wazero.NewRuntime(ctx)
|
||||
instance.runtime = wazero.NewRuntime(ctx)
|
||||
|
||||
env := vfs.Export(sqlite3.runtime.NewHostModuleBuilder("env"))
|
||||
_, sqlite3.err = env.Instantiate(ctx)
|
||||
if sqlite3.err != nil {
|
||||
env := instance.runtime.NewHostModuleBuilder("env")
|
||||
env = vfs.ExportHostFunctions(env)
|
||||
env = exportHostFunctions(env)
|
||||
_, instance.err = env.Instantiate(ctx)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
bin := Binary
|
||||
if bin == nil && Path != "" {
|
||||
bin, sqlite3.err = os.ReadFile(Path)
|
||||
if sqlite3.err != nil {
|
||||
bin, instance.err = os.ReadFile(Path)
|
||||
if instance.err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if bin == nil {
|
||||
sqlite3.err = util.BinaryErr
|
||||
instance.err = util.BinaryErr
|
||||
return
|
||||
}
|
||||
|
||||
sqlite3.compiled, sqlite3.err = sqlite3.runtime.CompileModule(ctx, bin)
|
||||
instance.compiled, instance.err = instance.runtime.CompileModule(ctx, bin)
|
||||
}
|
||||
|
||||
type module struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
vfs io.Closer
|
||||
api sqliteAPI
|
||||
arg []uint64
|
||||
type sqlite struct {
|
||||
ctx context.Context
|
||||
mod api.Module
|
||||
closer io.Closer
|
||||
api sqliteAPI
|
||||
stack [8]uint64
|
||||
}
|
||||
|
||||
func newModule(mod api.Module) (m *module, err error) {
|
||||
m = &module{}
|
||||
m.mod = mod
|
||||
m.ctx, m.vfs = vfs.Context(context.Background())
|
||||
type sqliteKey struct{}
|
||||
|
||||
func newSQLite(mod api.Module) (sqlt *sqlite, err error) {
|
||||
sqlt = new(sqlite)
|
||||
sqlt.ctx, sqlt.closer = util.NewContext(context.Background())
|
||||
sqlt.ctx = context.WithValue(sqlt.ctx, sqliteKey{}, sqlt)
|
||||
sqlt.mod = mod
|
||||
|
||||
getFun := func(name string) api.Function {
|
||||
f := mod.ExportedFunction(name)
|
||||
@@ -105,7 +110,7 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
return util.ReadUint32(mod, uint32(g.Get()))
|
||||
}
|
||||
|
||||
m.api = sqliteAPI{
|
||||
sqlt.api = sqliteAPI{
|
||||
free: getFun("free"),
|
||||
malloc: getFun("malloc"),
|
||||
destructor: getVal("malloc_destructor"),
|
||||
@@ -139,9 +144,6 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
columnText: getFun("sqlite3_column_text"),
|
||||
columnBlob: getFun("sqlite3_column_blob"),
|
||||
columnBytes: getFun("sqlite3_column_bytes"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
blobOpen: getFun("sqlite3_blob_open"),
|
||||
blobClose: getFun("sqlite3_blob_close"),
|
||||
blobReopen: getFun("sqlite3_blob_reopen"),
|
||||
@@ -153,21 +155,48 @@ func newModule(mod api.Module) (m *module, err error) {
|
||||
backupFinish: getFun("sqlite3_backup_finish"),
|
||||
backupRemaining: getFun("sqlite3_backup_remaining"),
|
||||
backupPageCount: getFun("sqlite3_backup_pagecount"),
|
||||
interrupt: getVal("sqlite3_interrupt_offset"),
|
||||
changes: getFun("sqlite3_changes64"),
|
||||
lastRowid: getFun("sqlite3_last_insert_rowid"),
|
||||
autocommit: getFun("sqlite3_get_autocommit"),
|
||||
anyCollation: getFun("sqlite3_anycollseq_init"),
|
||||
createCollation: getFun("sqlite3_create_collation_go"),
|
||||
createFunction: getFun("sqlite3_create_function_go"),
|
||||
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
|
||||
createWindow: getFun("sqlite3_create_window_function_go"),
|
||||
aggregateCtx: getFun("sqlite3_aggregate_context"),
|
||||
userData: getFun("sqlite3_user_data"),
|
||||
setAuxData: getFun("sqlite3_set_auxdata_go"),
|
||||
getAuxData: getFun("sqlite3_get_auxdata"),
|
||||
valueType: getFun("sqlite3_value_type"),
|
||||
valueInteger: getFun("sqlite3_value_int64"),
|
||||
valueFloat: getFun("sqlite3_value_double"),
|
||||
valueText: getFun("sqlite3_value_text"),
|
||||
valueBlob: getFun("sqlite3_value_blob"),
|
||||
valueBytes: getFun("sqlite3_value_bytes"),
|
||||
resultNull: getFun("sqlite3_result_null"),
|
||||
resultInteger: getFun("sqlite3_result_int64"),
|
||||
resultFloat: getFun("sqlite3_result_double"),
|
||||
resultText: getFun("sqlite3_result_text64"),
|
||||
resultBlob: getFun("sqlite3_result_blob64"),
|
||||
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
|
||||
resultError: getFun("sqlite3_result_error"),
|
||||
resultErrorCode: getFun("sqlite3_result_error_code"),
|
||||
resultErrorMem: getFun("sqlite3_result_error_nomem"),
|
||||
resultErrorBig: getFun("sqlite3_result_error_toobig"),
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
return sqlt, nil
|
||||
}
|
||||
|
||||
func (m *module) close() error {
|
||||
err := m.mod.Close(m.ctx)
|
||||
m.vfs.Close()
|
||||
func (sqlt *sqlite) close() error {
|
||||
err := sqlt.mod.Close(sqlt.ctx)
|
||||
sqlt.closer.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
|
||||
if rc == _OK {
|
||||
return nil
|
||||
}
|
||||
@@ -178,22 +207,17 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
|
||||
var r []uint64
|
||||
|
||||
r = m.call(m.api.errstr, rc)
|
||||
if r != nil {
|
||||
err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
|
||||
if r := sqlt.call(sqlt.api.errstr, rc); r != 0 {
|
||||
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
r = m.call(m.api.errmsg, uint64(handle))
|
||||
if r != nil {
|
||||
err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
|
||||
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
|
||||
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
|
||||
}
|
||||
|
||||
if sql != nil {
|
||||
r = m.call(m.api.erroff, uint64(handle))
|
||||
if r != nil && r[0] != math.MaxUint32 {
|
||||
err.sql = sql[0][r[0]:]
|
||||
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
|
||||
err.sql = sql[0][r:]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,61 +228,60 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
|
||||
return &err
|
||||
}
|
||||
|
||||
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
|
||||
m.arg = append(m.arg[:0], params...)
|
||||
r, err := fn.Call(m.ctx, m.arg...)
|
||||
func (sqlt *sqlite) call(fn api.Function, params ...uint64) uint64 {
|
||||
copy(sqlt.stack[:], params)
|
||||
err := fn.CallWithStack(sqlt.ctx, sqlt.stack[:])
|
||||
if err != nil {
|
||||
// The module closed or panicked; release resources.
|
||||
m.vfs.Close()
|
||||
sqlt.closer.Close()
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
return sqlt.stack[0]
|
||||
}
|
||||
|
||||
func (m *module) free(ptr uint32) {
|
||||
func (sqlt *sqlite) free(ptr uint32) {
|
||||
if ptr == 0 {
|
||||
return
|
||||
}
|
||||
m.call(m.api.free, uint64(ptr))
|
||||
sqlt.call(sqlt.api.free, uint64(ptr))
|
||||
}
|
||||
|
||||
func (m *module) new(size uint64) uint32 {
|
||||
func (sqlt *sqlite) new(size uint64) uint32 {
|
||||
if size > _MAX_ALLOCATION_SIZE {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
r := m.call(m.api.malloc, size)
|
||||
ptr := uint32(r[0])
|
||||
ptr := uint32(sqlt.call(sqlt.api.malloc, size))
|
||||
if ptr == 0 && size != 0 {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newBytes(b []byte) uint32 {
|
||||
func (sqlt *sqlite) newBytes(b []byte) uint32 {
|
||||
if b == nil {
|
||||
return 0
|
||||
}
|
||||
ptr := m.new(uint64(len(b)))
|
||||
util.WriteBytes(m.mod, ptr, b)
|
||||
ptr := sqlt.new(uint64(len(b)))
|
||||
util.WriteBytes(sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newString(s string) uint32 {
|
||||
ptr := m.new(uint64(len(s) + 1))
|
||||
util.WriteString(m.mod, ptr, s)
|
||||
func (sqlt *sqlite) newString(s string) uint32 {
|
||||
ptr := sqlt.new(uint64(len(s) + 1))
|
||||
util.WriteString(sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (m *module) newArena(size uint64) arena {
|
||||
func (sqlt *sqlite) newArena(size uint64) arena {
|
||||
return arena{
|
||||
m: m,
|
||||
base: m.new(size),
|
||||
sqlt: sqlt,
|
||||
size: uint32(size),
|
||||
base: sqlt.new(size),
|
||||
}
|
||||
}
|
||||
|
||||
type arena struct {
|
||||
m *module
|
||||
sqlt *sqlite
|
||||
ptrs []uint32
|
||||
base uint32
|
||||
next uint32
|
||||
@@ -266,17 +289,17 @@ type arena struct {
|
||||
}
|
||||
|
||||
func (a *arena) free() {
|
||||
if a.m == nil {
|
||||
if a.sqlt == nil {
|
||||
return
|
||||
}
|
||||
a.reset()
|
||||
a.m.free(a.base)
|
||||
a.m = nil
|
||||
a.sqlt.free(a.base)
|
||||
a.sqlt = nil
|
||||
}
|
||||
|
||||
func (a *arena) reset() {
|
||||
for _, ptr := range a.ptrs {
|
||||
a.m.free(ptr)
|
||||
a.sqlt.free(ptr)
|
||||
}
|
||||
a.ptrs = nil
|
||||
a.next = 0
|
||||
@@ -288,7 +311,7 @@ func (a *arena) new(size uint64) uint32 {
|
||||
a.next += uint32(size)
|
||||
return ptr
|
||||
}
|
||||
ptr := a.m.new(size)
|
||||
ptr := a.sqlt.new(size)
|
||||
a.ptrs = append(a.ptrs, ptr)
|
||||
return ptr
|
||||
}
|
||||
@@ -298,13 +321,13 @@ func (a *arena) bytes(b []byte) uint32 {
|
||||
return 0
|
||||
}
|
||||
ptr := a.new(uint64(len(b)))
|
||||
util.WriteBytes(a.m.mod, ptr, b)
|
||||
util.WriteBytes(a.sqlt.mod, ptr, b)
|
||||
return ptr
|
||||
}
|
||||
|
||||
func (a *arena) string(s string) uint32 {
|
||||
ptr := a.new(uint64(len(s) + 1))
|
||||
util.WriteString(a.m.mod, ptr, s)
|
||||
util.WriteString(a.sqlt.mod, ptr, s)
|
||||
return ptr
|
||||
}
|
||||
|
||||
@@ -324,10 +347,10 @@ type sqliteAPI struct {
|
||||
step api.Function
|
||||
exec api.Function
|
||||
clearBindings api.Function
|
||||
bindNull api.Function
|
||||
bindCount api.Function
|
||||
bindIndex api.Function
|
||||
bindName api.Function
|
||||
bindNull api.Function
|
||||
bindInteger api.Function
|
||||
bindFloat api.Function
|
||||
bindText api.Function
|
||||
@@ -341,9 +364,6 @@ type sqliteAPI struct {
|
||||
columnText api.Function
|
||||
columnBlob api.Function
|
||||
columnBytes api.Function
|
||||
autocommit api.Function
|
||||
lastRowid api.Function
|
||||
changes api.Function
|
||||
blobOpen api.Function
|
||||
blobClose api.Function
|
||||
blobReopen api.Function
|
||||
@@ -355,6 +375,33 @@ type sqliteAPI struct {
|
||||
backupFinish api.Function
|
||||
backupRemaining api.Function
|
||||
backupPageCount api.Function
|
||||
changes api.Function
|
||||
lastRowid api.Function
|
||||
autocommit api.Function
|
||||
anyCollation api.Function
|
||||
createCollation api.Function
|
||||
createFunction api.Function
|
||||
createAggregate api.Function
|
||||
createWindow api.Function
|
||||
aggregateCtx api.Function
|
||||
userData api.Function
|
||||
setAuxData api.Function
|
||||
getAuxData api.Function
|
||||
valueType api.Function
|
||||
valueInteger api.Function
|
||||
valueFloat api.Function
|
||||
valueText api.Function
|
||||
valueBlob api.Function
|
||||
valueBytes api.Function
|
||||
resultNull api.Function
|
||||
resultInteger api.Function
|
||||
resultFloat api.Function
|
||||
resultText api.Function
|
||||
resultBlob api.Function
|
||||
resultZeroBlob api.Function
|
||||
resultError api.Function
|
||||
resultErrorCode api.Function
|
||||
resultErrorMem api.Function
|
||||
resultErrorBig api.Function
|
||||
destructor uint32
|
||||
interrupt uint32
|
||||
}
|
||||
1
sqlite3/.gitignore
vendored
1
sqlite3/.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
ext/
|
||||
sqlite3.c
|
||||
sqlite3.h
|
||||
sqlite3ext.h
|
||||
@@ -3,29 +3,33 @@ set -euo pipefail
|
||||
|
||||
cd -P -- "$(dirname -- "$0")"
|
||||
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3410200.zip"
|
||||
curl -#OL "https://sqlite.org/2023/sqlite-amalgamation-3420000.zip"
|
||||
unzip -d . sqlite-amalgamation-*.zip
|
||||
mv sqlite-amalgamation-*/sqlite3* .
|
||||
rm -rf sqlite-amalgamation-*
|
||||
|
||||
cat *.patch | patch
|
||||
|
||||
mkdir -p ext/
|
||||
cd ext/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uuid.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/base64.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/anycollseq.c"
|
||||
cd ~-
|
||||
|
||||
cd ../internal/vfs/tests/mptest/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/mptest/multiwrite01.test"
|
||||
cd ../vfs/tests/mptest/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash01.test"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.subtest"
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
|
||||
cd ~-
|
||||
|
||||
cd ../internal/vfs/tests/speedtest1/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.41.2/test/speedtest1.c"
|
||||
cd ../vfs/tests/speedtest1/testdata/
|
||||
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
|
||||
cd ~-
|
||||
1
sqlite3/ext/.gitignore
vendored
1
sqlite3/ext/.gitignore
vendored
@@ -1 +0,0 @@
|
||||
*.c
|
||||
40
sqlite3/func.c
Normal file
40
sqlite3/func.c
Normal file
@@ -0,0 +1,40 @@
|
||||
#include <string.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_compare(void *, int, const void *, int, const void *);
|
||||
void go_func(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_step(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_final(sqlite3_context *);
|
||||
void go_value(sqlite3_context *);
|
||||
void go_inverse(sqlite3_context *, int, sqlite3_value **);
|
||||
void go_destroy(void *);
|
||||
|
||||
int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) {
|
||||
return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp,
|
||||
go_func, NULL, NULL, go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName,
|
||||
int nArg, int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, NULL, NULL,
|
||||
go_destroy);
|
||||
}
|
||||
|
||||
int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg,
|
||||
int flags, void *pApp) {
|
||||
return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags,
|
||||
pApp, go_step, go_final, go_value,
|
||||
go_inverse, go_destroy);
|
||||
}
|
||||
|
||||
void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) {
|
||||
sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy);
|
||||
}
|
||||
14
sqlite3/locking_mode.patch
Normal file
14
sqlite3/locking_mode.patch
Normal file
@@ -0,0 +1,14 @@
|
||||
# Use exclusive locking mode for WAL databases with v1 VFSes.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -63210,7 +63210,9 @@
|
||||
SQLITE_PRIVATE int sqlite3PagerWalSupported(Pager *pPager){
|
||||
const sqlite3_io_methods *pMethods = pPager->fd->pMethods;
|
||||
if( pPager->noLock ) return 0;
|
||||
- return pPager->exclusiveMode || (pMethods->iVersion>=2 && pMethods->xShmMap);
|
||||
+ if( pMethods->iVersion>=2 && pMethods->xShmMap ) return 1;
|
||||
+ pPager->exclusiveMode = 1;
|
||||
+ return 1;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -1,25 +1,21 @@
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
// Amalgamation
|
||||
#include "sqlite3.c"
|
||||
//
|
||||
#include "os.c"
|
||||
//
|
||||
// VFS
|
||||
#include "vfs.c"
|
||||
// Extensions
|
||||
#include "ext/anycollseq.c"
|
||||
#include "ext/base64.c"
|
||||
#include "ext/decimal.c"
|
||||
#include "ext/regexp.c"
|
||||
#include "ext/series.c"
|
||||
#include "ext/uint.c"
|
||||
#include "ext/uuid.c"
|
||||
#include "func.c"
|
||||
#include "time.c"
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
size_t sqlite3_interrupt_offset = offsetof(sqlite3, u1.isInterrupted);
|
||||
|
||||
int sqlite3_os_init() {
|
||||
return sqlite3_vfs_register(os_vfs(), /*default=*/true);
|
||||
}
|
||||
|
||||
__attribute__((constructor)) void init() {
|
||||
sqlite3_initialize();
|
||||
sqlite3_auto_extension((void (*)(void))sqlite3_base_init);
|
||||
|
||||
26
sqlite3/open_memory.patch
Normal file
26
sqlite3/open_memory.patch
Normal file
@@ -0,0 +1,26 @@
|
||||
# Allow the VFS to force memory journal mode
|
||||
# regardless of SQLITE_OMIT_DESERIALIZE.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -60425,11 +60425,7 @@
|
||||
int rc = SQLITE_OK; /* Return code */
|
||||
int tempFile = 0; /* True for temp files (incl. in-memory files) */
|
||||
int memDb = 0; /* True if this is an in-memory file */
|
||||
-#ifndef SQLITE_OMIT_DESERIALIZE
|
||||
int memJM = 0; /* Memory journal mode */
|
||||
-#else
|
||||
-# define memJM 0
|
||||
-#endif
|
||||
int readOnly = 0; /* True if this is a read-only file */
|
||||
int journalFileSize; /* Bytes to allocate for each journal fd */
|
||||
char *zPathname = 0; /* Full path to database file */
|
||||
@@ -60628,9 +60624,7 @@
|
||||
int fout = 0; /* VFS flags returned by xOpen() */
|
||||
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
|
||||
assert( !memDb );
|
||||
-#ifndef SQLITE_OMIT_DESERIALIZE
|
||||
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
|
||||
-#endif
|
||||
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;
|
||||
|
||||
/* If the file was successfully opened for read/write access,
|
||||
97
sqlite3/os.c
97
sqlite3/os.c
@@ -1,97 +0,0 @@
|
||||
#include <time.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int os_localtime(struct tm *, sqlite3_int64);
|
||||
|
||||
int os_randomness(sqlite3_vfs *, int nByte, char *zOut);
|
||||
int os_sleep(sqlite3_vfs *, int microseconds);
|
||||
int os_current_time(sqlite3_vfs *, double *);
|
||||
int os_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
|
||||
|
||||
int os_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
|
||||
int *pOutFlags);
|
||||
int os_delete(sqlite3_vfs *, const char *zName, int syncDir);
|
||||
int os_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
|
||||
int os_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
|
||||
|
||||
struct os_file {
|
||||
sqlite3_file base;
|
||||
int handle;
|
||||
};
|
||||
|
||||
static_assert(offsetof(struct os_file, handle) == 4, "Unexpected offset");
|
||||
|
||||
int os_close(sqlite3_file *);
|
||||
int os_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int os_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int os_truncate(sqlite3_file *, sqlite3_int64 size);
|
||||
int os_sync(sqlite3_file *, int flags);
|
||||
int os_file_size(sqlite3_file *, sqlite3_int64 *pSize);
|
||||
int os_file_control(sqlite3_file *, int op, void *pArg);
|
||||
int os_sector_size(sqlite3_file *file);
|
||||
int os_device_characteristics(sqlite3_file *file);
|
||||
|
||||
int os_lock(sqlite3_file *, int eLock);
|
||||
int os_unlock(sqlite3_file *, int eLock);
|
||||
int os_check_reserved_lock(sqlite3_file *, int *pResOut);
|
||||
|
||||
static int os_file_control_w(sqlite3_file *file, int op, void *pArg) {
|
||||
struct os_file *pFile = (struct os_file *)file;
|
||||
if (op == SQLITE_FCNTL_VFSNAME) {
|
||||
*(char **)pArg = sqlite3_mprintf("%s", "os");
|
||||
return SQLITE_OK;
|
||||
}
|
||||
return os_file_control(file, op, pArg);
|
||||
}
|
||||
|
||||
static int os_open_w(sqlite3_vfs *vfs, sqlite3_filename zName,
|
||||
sqlite3_file *file, int flags, int *pOutFlags) {
|
||||
static const sqlite3_io_methods os_io = {
|
||||
.iVersion = 1,
|
||||
.xClose = os_close,
|
||||
.xRead = os_read,
|
||||
.xWrite = os_write,
|
||||
.xTruncate = os_truncate,
|
||||
.xSync = os_sync,
|
||||
.xFileSize = os_file_size,
|
||||
.xLock = os_lock,
|
||||
.xUnlock = os_unlock,
|
||||
.xCheckReservedLock = os_check_reserved_lock,
|
||||
.xFileControl = os_file_control_w,
|
||||
.xSectorSize = os_sector_size,
|
||||
.xDeviceCharacteristics = os_device_characteristics,
|
||||
};
|
||||
memset(file, 0, sizeof(struct os_file));
|
||||
int rc = os_open(vfs, zName, file, flags, pOutFlags);
|
||||
if (rc) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
file->pMethods = &os_io;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
sqlite3_vfs *os_vfs() {
|
||||
static sqlite3_vfs os_vfs = {
|
||||
.iVersion = 2,
|
||||
.szOsFile = sizeof(struct os_file),
|
||||
.mxPathname = 512,
|
||||
.zName = "os",
|
||||
|
||||
.xOpen = os_open_w,
|
||||
.xDelete = os_delete,
|
||||
.xAccess = os_access,
|
||||
.xFullPathname = os_full_pathname,
|
||||
|
||||
.xRandomness = os_randomness,
|
||||
.xSleep = os_sleep,
|
||||
.xCurrentTime = os_current_time,
|
||||
.xCurrentTimeInt64 = os_current_time_64,
|
||||
};
|
||||
return &os_vfs;
|
||||
}
|
||||
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
|
||||
return os_localtime(pTm, (sqlite3_int64)*pTime);
|
||||
}
|
||||
@@ -29,18 +29,18 @@
|
||||
#define SQLITE_USE_ALLOCA
|
||||
|
||||
// Other Options
|
||||
// #define SQLITE_ALLOW_URI_AUTHORITY
|
||||
#define SQLITE_ALLOW_URI_AUTHORITY
|
||||
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
|
||||
#define SQLITE_ENABLE_ATOMIC_WRITE
|
||||
#define SQLITE_OMIT_DESERIALIZE
|
||||
|
||||
// Because WASM does not support shared memory,
|
||||
// SQLite disables WAL for WASM builds.
|
||||
// We set the default locking mode to EXCLUSIVE instead.
|
||||
// We patch SQLite to use exclusive locking mode instead.
|
||||
// https://www.sqlite.org/wal.html#noshm
|
||||
#undef SQLITE_OMIT_WAL
|
||||
#ifndef SQLITE_DEFAULT_LOCKING_MODE
|
||||
#define SQLITE_DEFAULT_LOCKING_MODE 1
|
||||
#endif
|
||||
|
||||
// Recommended Extensions
|
||||
// Amalgamated Extensions
|
||||
|
||||
#define SQLITE_ENABLE_MATH_FUNCTIONS 1
|
||||
#define SQLITE_ENABLE_JSON1 1
|
||||
@@ -52,8 +52,8 @@
|
||||
#define SQLITE_ENABLE_GEOPOLY 1
|
||||
|
||||
// Session Extension
|
||||
// #define SQLITE_ENABLE_SESSION 1
|
||||
// #define SQLITE_ENABLE_PREUPDATE_HOOK 1
|
||||
// #define SQLITE_ENABLE_SESSION
|
||||
// #define SQLITE_ENABLE_PREUPDATE_HOOK
|
||||
|
||||
// Implemented in Go.
|
||||
// Implemented in vfs.c.
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime);
|
||||
137
sqlite3/vfs.c
Normal file
137
sqlite3/vfs.c
Normal file
@@ -0,0 +1,137 @@
|
||||
#include <time.h>
|
||||
|
||||
#include "sqlite3.h"
|
||||
|
||||
int go_localtime(struct tm *, sqlite3_int64);
|
||||
int go_vfs_find(const char *zVfsName);
|
||||
|
||||
int go_randomness(sqlite3_vfs *, int nByte, char *zOut);
|
||||
int go_sleep(sqlite3_vfs *, int microseconds);
|
||||
int go_current_time(sqlite3_vfs *, double *);
|
||||
int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *);
|
||||
|
||||
int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags,
|
||||
int *pOutFlags);
|
||||
int go_delete(sqlite3_vfs *, const char *zName, int syncDir);
|
||||
int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut);
|
||||
int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut);
|
||||
|
||||
int go_close(sqlite3_file *);
|
||||
int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst);
|
||||
int go_truncate(sqlite3_file *, sqlite3_int64 size);
|
||||
int go_sync(sqlite3_file *, int flags);
|
||||
int go_file_size(sqlite3_file *, sqlite3_int64 *pSize);
|
||||
int go_file_control(sqlite3_file *, int op, void *pArg);
|
||||
int go_sector_size(sqlite3_file *file);
|
||||
int go_device_characteristics(sqlite3_file *file);
|
||||
|
||||
int go_lock(sqlite3_file *, int eLock);
|
||||
int go_unlock(sqlite3_file *, int eLock);
|
||||
int go_check_reserved_lock(sqlite3_file *, int *pResOut);
|
||||
|
||||
static int go_open_wrapper(sqlite3_vfs *vfs, sqlite3_filename zName,
|
||||
sqlite3_file *file, int flags, int *pOutFlags) {
|
||||
static const sqlite3_io_methods os_io = {
|
||||
.iVersion = 1,
|
||||
.xClose = go_close,
|
||||
.xRead = go_read,
|
||||
.xWrite = go_write,
|
||||
.xTruncate = go_truncate,
|
||||
.xSync = go_sync,
|
||||
.xFileSize = go_file_size,
|
||||
.xLock = go_lock,
|
||||
.xUnlock = go_unlock,
|
||||
.xCheckReservedLock = go_check_reserved_lock,
|
||||
.xFileControl = go_file_control,
|
||||
.xSectorSize = go_sector_size,
|
||||
.xDeviceCharacteristics = go_device_characteristics,
|
||||
};
|
||||
memset(file, 0, vfs->szOsFile);
|
||||
int rc = go_open(vfs, zName, file, flags, pOutFlags);
|
||||
if (rc) {
|
||||
return rc;
|
||||
}
|
||||
file->pMethods = &os_io;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
struct go_file {
|
||||
sqlite3_file base;
|
||||
int handle;
|
||||
};
|
||||
|
||||
int sqlite3_os_init() {
|
||||
static sqlite3_vfs os_vfs = {
|
||||
.iVersion = 2,
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = "os",
|
||||
|
||||
.xOpen = go_open_wrapper,
|
||||
.xDelete = go_delete,
|
||||
.xAccess = go_access,
|
||||
.xFullPathname = go_full_pathname,
|
||||
|
||||
.xRandomness = go_randomness,
|
||||
.xSleep = go_sleep,
|
||||
.xCurrentTime = go_current_time,
|
||||
.xCurrentTimeInt64 = go_current_time_64,
|
||||
};
|
||||
return sqlite3_vfs_register(&os_vfs, /*default=*/true);
|
||||
}
|
||||
|
||||
sqlite3_destructor_type malloc_destructor = &free;
|
||||
|
||||
int localtime_s(struct tm *const pTm, time_t const *const pTime) {
|
||||
return go_localtime(pTm, (sqlite3_int64)*pTime);
|
||||
}
|
||||
|
||||
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
|
||||
if (zVfsName) {
|
||||
static sqlite3_vfs *go_vfs_list;
|
||||
sqlite3_vfs *found = NULL;
|
||||
for (sqlite3_vfs **next = &go_vfs_list; *next;) {
|
||||
sqlite3_vfs *it = *next;
|
||||
if (go_vfs_find(it->zName)) {
|
||||
if (!strcmp(zVfsName, it->zName)) found = it;
|
||||
next = &it->pNext;
|
||||
} else {
|
||||
*next = it->pNext;
|
||||
free(it);
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
return found;
|
||||
}
|
||||
if (go_vfs_find(zVfsName)) {
|
||||
sqlite3_vfs *prev = go_vfs_list;
|
||||
go_vfs_list = malloc(sizeof(sqlite3_vfs) + strlen(zVfsName) + 1);
|
||||
char *name = (char *)(go_vfs_list + 1);
|
||||
strcpy(name, zVfsName);
|
||||
*go_vfs_list = (sqlite3_vfs){
|
||||
.iVersion = 2,
|
||||
.szOsFile = sizeof(struct go_file),
|
||||
.mxPathname = 512,
|
||||
.zName = name,
|
||||
.pNext = prev,
|
||||
|
||||
.xOpen = go_open_wrapper,
|
||||
.xDelete = go_delete,
|
||||
.xAccess = go_access,
|
||||
.xFullPathname = go_full_pathname,
|
||||
|
||||
.xRandomness = go_randomness,
|
||||
.xSleep = go_sleep,
|
||||
.xCurrentTime = go_current_time,
|
||||
.xCurrentTimeInt64 = go_current_time_64,
|
||||
};
|
||||
return go_vfs_list;
|
||||
}
|
||||
}
|
||||
return sqlite3_vfs_find_orig(zVfsName);
|
||||
}
|
||||
|
||||
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3_vfs, zName) == 16, "Unexpected offset");
|
||||
static_assert(offsetof(sqlite3, u1.isInterrupted) == 288, "Unexpected offset");
|
||||
12
sqlite3/vfs_find.patch
Normal file
12
sqlite3/vfs_find.patch
Normal file
@@ -0,0 +1,12 @@
|
||||
# Wrap sqlite3_vfs_find.
|
||||
--- sqlite3.c.orig
|
||||
+++ sqlite3.c
|
||||
@@ -25394,7 +25394,7 @@
|
||||
** Locate a VFS by name. If no name is given, simply return the
|
||||
** first VFS on the list.
|
||||
*/
|
||||
-SQLITE_API sqlite3_vfs *sqlite3_vfs_find(const char *zVfs){
|
||||
+SQLITE_API sqlite3_vfs *sqlite3_vfs_find_orig(const char *zVfs){
|
||||
sqlite3_vfs *pVfs = 0;
|
||||
#if SQLITE_THREADSAFE
|
||||
sqlite3_mutex *mutex;
|
||||
@@ -12,76 +12,75 @@ func init() {
|
||||
Path = "./embed/sqlite3.wasm"
|
||||
}
|
||||
|
||||
func TestConn_error_OOM(t *testing.T) {
|
||||
func Test_sqlite_error_OOM(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
m.error(uint64(NOMEM), 0)
|
||||
sqlite.error(uint64(NOMEM), 0)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_call_nil(t *testing.T) {
|
||||
func Test_sqlite_call_closed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
sqlite.close()
|
||||
|
||||
defer func() { _ = recover() }()
|
||||
m.call(m.api.free)
|
||||
sqlite.call(sqlite.api.free)
|
||||
t.Error("want panic")
|
||||
}
|
||||
|
||||
func TestConn_new(t *testing.T) {
|
||||
func Test_sqlite_new(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
t.Run("MaxUint32", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
m.new(math.MaxUint32)
|
||||
sqlite.new(math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
})
|
||||
|
||||
t.Run("_MAX_ALLOCATION_SIZE", func(t *testing.T) {
|
||||
defer func() { _ = recover() }()
|
||||
m.new(_MAX_ALLOCATION_SIZE)
|
||||
m.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
sqlite.new(_MAX_ALLOCATION_SIZE)
|
||||
t.Error("want panic")
|
||||
})
|
||||
}
|
||||
|
||||
func TestConn_newArena(t *testing.T) {
|
||||
func Test_sqlite_newArena(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
arena := m.newArena(16)
|
||||
arena := sqlite.newArena(16)
|
||||
defer arena.free()
|
||||
|
||||
const title = "Lorem ipsum"
|
||||
|
||||
ptr := arena.string(title)
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
@@ -90,120 +89,133 @@ func TestConn_newArena(t *testing.T) {
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != body {
|
||||
t.Errorf("got %q, want %q", got, body)
|
||||
}
|
||||
|
||||
ptr = arena.bytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("want nullptr")
|
||||
}
|
||||
ptr = arena.bytes([]byte(title))
|
||||
if ptr == 0 {
|
||||
t.Fatalf("got nullptr")
|
||||
}
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(title))); string(got) != title {
|
||||
t.Errorf("got %q, want %q", got, title)
|
||||
}
|
||||
|
||||
arena.free()
|
||||
}
|
||||
|
||||
func TestConn_newBytes(t *testing.T) {
|
||||
func Test_sqlite_newBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newBytes(nil)
|
||||
ptr := sqlite.newBytes(nil)
|
||||
if ptr != 0 {
|
||||
t.Errorf("got %#x, want nullptr", ptr)
|
||||
}
|
||||
|
||||
buf := []byte("sqlite3")
|
||||
ptr = m.newBytes(buf)
|
||||
ptr = sqlite.newBytes(buf)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := buf
|
||||
if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_newString(t *testing.T) {
|
||||
func Test_sqlite_newString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newString("")
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3\000sqlite3"
|
||||
ptr = m.newString(str)
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := str + "\000"
|
||||
if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
if got := util.View(sqlite.mod, ptr, uint64(len(want))); string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_getString(t *testing.T) {
|
||||
func Test_sqlite_getString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
ptr := m.newString("")
|
||||
ptr := sqlite.newString("")
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
str := "sqlite3" + "\000 drop this"
|
||||
ptr = m.newString(str)
|
||||
ptr = sqlite.newString(str)
|
||||
if ptr == 0 {
|
||||
t.Fatal("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
want := "sqlite3"
|
||||
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want {
|
||||
if got := util.ReadString(sqlite.mod, ptr, math.MaxUint32); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if got := util.ReadString(m.mod, ptr, 0); got != "" {
|
||||
if got := util.ReadString(sqlite.mod, ptr, 0); got != "" {
|
||||
t.Errorf("got %q, want empty", got)
|
||||
}
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(m.mod, ptr, uint32(len(want)/2))
|
||||
util.ReadString(sqlite.mod, ptr, uint32(len(want)/2))
|
||||
t.Error("want panic")
|
||||
}()
|
||||
|
||||
func() {
|
||||
defer func() { _ = recover() }()
|
||||
util.ReadString(m.mod, 0, math.MaxUint32)
|
||||
util.ReadString(sqlite.mod, 0, math.MaxUint32)
|
||||
t.Error("want panic")
|
||||
}()
|
||||
}
|
||||
|
||||
func TestConn_free(t *testing.T) {
|
||||
func Test_sqlite_free(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m, err := instantiateModule()
|
||||
sqlite, err := instantiateSQLite()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer m.close()
|
||||
defer sqlite.close()
|
||||
|
||||
m.free(0)
|
||||
sqlite.free(0)
|
||||
|
||||
ptr := m.new(1)
|
||||
ptr := sqlite.new(1)
|
||||
if ptr == 0 {
|
||||
t.Error("got nullptr, want a pointer")
|
||||
}
|
||||
|
||||
m.free(ptr)
|
||||
sqlite.free(ptr)
|
||||
}
|
||||
73
stmt.go
73
stmt.go
@@ -29,7 +29,7 @@ func (s *Stmt) Close() error {
|
||||
r := s.c.call(s.c.api.finalize, uint64(s.handle))
|
||||
|
||||
s.handle = 0
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// Reset resets the prepared statement object.
|
||||
@@ -38,7 +38,7 @@ func (s *Stmt) Close() error {
|
||||
func (s *Stmt) Reset() error {
|
||||
r := s.c.call(s.c.api.reset, uint64(s.handle))
|
||||
s.err = nil
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// ClearBindings resets all bindings on the prepared statement.
|
||||
@@ -46,7 +46,7 @@ func (s *Stmt) Reset() error {
|
||||
// https://www.sqlite.org/c3ref/clear_bindings.html
|
||||
func (s *Stmt) ClearBindings() error {
|
||||
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// Step evaluates the SQL statement.
|
||||
@@ -61,13 +61,13 @@ func (s *Stmt) ClearBindings() error {
|
||||
func (s *Stmt) Step() bool {
|
||||
s.c.checkInterrupt()
|
||||
r := s.c.call(s.c.api.step, uint64(s.handle))
|
||||
if r[0] == _ROW {
|
||||
if r == _ROW {
|
||||
return true
|
||||
}
|
||||
if r[0] == _DONE {
|
||||
if r == _DONE {
|
||||
s.err = nil
|
||||
} else {
|
||||
s.err = s.c.error(r[0])
|
||||
s.err = s.c.error(r)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func (s *Stmt) Exec() error {
|
||||
func (s *Stmt) BindCount() int {
|
||||
r := s.c.call(s.c.api.bindCount,
|
||||
uint64(s.handle))
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// BindIndex returns the index of a parameter in the prepared statement
|
||||
@@ -106,7 +106,7 @@ func (s *Stmt) BindIndex(name string) int {
|
||||
namePtr := s.c.arena.string(name)
|
||||
r := s.c.call(s.c.api.bindIndex,
|
||||
uint64(s.handle), uint64(namePtr))
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// BindName returns the name of a parameter in the prepared statement.
|
||||
@@ -117,7 +117,7 @@ func (s *Stmt) BindName(param int) string {
|
||||
r := s.c.call(s.c.api.bindName,
|
||||
uint64(s.handle), uint64(param))
|
||||
|
||||
ptr := uint32(r[0])
|
||||
ptr := uint32(r)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -131,10 +131,11 @@ func (s *Stmt) BindName(param int) string {
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/bind_blob.html
|
||||
func (s *Stmt) BindBool(param int, value bool) error {
|
||||
var i int64
|
||||
if value {
|
||||
return s.BindInt64(param, 1)
|
||||
i = 1
|
||||
}
|
||||
return s.BindInt64(param, 0)
|
||||
return s.BindInt64(param, i)
|
||||
}
|
||||
|
||||
// BindInt binds an int to the prepared statement.
|
||||
@@ -152,7 +153,7 @@ func (s *Stmt) BindInt(param int, value int) error {
|
||||
func (s *Stmt) BindInt64(param int, value int64) error {
|
||||
r := s.c.call(s.c.api.bindInteger,
|
||||
uint64(s.handle), uint64(param), uint64(value))
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindFloat binds a float64 to the prepared statement.
|
||||
@@ -162,7 +163,7 @@ func (s *Stmt) BindInt64(param int, value int64) error {
|
||||
func (s *Stmt) BindFloat(param int, value float64) error {
|
||||
r := s.c.call(s.c.api.bindFloat,
|
||||
uint64(s.handle), uint64(param), math.Float64bits(value))
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindText binds a string to the prepared statement.
|
||||
@@ -175,7 +176,7 @@ func (s *Stmt) BindText(param int, value string) error {
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(value)),
|
||||
uint64(s.c.api.destructor), _UTF8)
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindBlob binds a []byte to the prepared statement.
|
||||
@@ -189,7 +190,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(value)),
|
||||
uint64(s.c.api.destructor))
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
|
||||
@@ -199,7 +200,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
|
||||
func (s *Stmt) BindZeroBlob(param int, n int64) error {
|
||||
r := s.c.call(s.c.api.bindZeroBlob,
|
||||
uint64(s.handle), uint64(param), uint64(n))
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindNull binds a NULL to the prepared statement.
|
||||
@@ -209,7 +210,7 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
|
||||
func (s *Stmt) BindNull(param int) error {
|
||||
r := s.c.call(s.c.api.bindNull,
|
||||
uint64(s.handle), uint64(param))
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// BindTime binds a [time.Time] to the prepared statement.
|
||||
@@ -244,7 +245,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
uint64(s.handle), uint64(param),
|
||||
uint64(ptr), uint64(len(buf)),
|
||||
uint64(s.c.api.destructor), _UTF8)
|
||||
return s.c.error(r[0])
|
||||
return s.c.error(r)
|
||||
}
|
||||
|
||||
// ColumnCount returns the number of columns in a result set.
|
||||
@@ -253,7 +254,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
|
||||
func (s *Stmt) ColumnCount() int {
|
||||
r := s.c.call(s.c.api.columnCount,
|
||||
uint64(s.handle))
|
||||
return int(r[0])
|
||||
return int(r)
|
||||
}
|
||||
|
||||
// ColumnName returns the name of the result column.
|
||||
@@ -264,7 +265,7 @@ func (s *Stmt) ColumnName(col int) string {
|
||||
r := s.c.call(s.c.api.columnName,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
ptr := uint32(r[0])
|
||||
ptr := uint32(r)
|
||||
if ptr == 0 {
|
||||
panic(util.OOMErr)
|
||||
}
|
||||
@@ -278,7 +279,7 @@ func (s *Stmt) ColumnName(col int) string {
|
||||
func (s *Stmt) ColumnType(col int) Datatype {
|
||||
r := s.c.call(s.c.api.columnType,
|
||||
uint64(s.handle), uint64(col))
|
||||
return Datatype(r[0])
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
// ColumnBool returns the value of the result column as a bool.
|
||||
@@ -310,7 +311,7 @@ func (s *Stmt) ColumnInt(col int) int {
|
||||
func (s *Stmt) ColumnInt64(col int) int64 {
|
||||
r := s.c.call(s.c.api.columnInteger,
|
||||
uint64(s.handle), uint64(col))
|
||||
return int64(r[0])
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// ColumnFloat returns the value of the result column as a float64.
|
||||
@@ -320,7 +321,7 @@ func (s *Stmt) ColumnInt64(col int) int64 {
|
||||
func (s *Stmt) ColumnFloat(col int) float64 {
|
||||
r := s.c.call(s.c.api.columnFloat,
|
||||
uint64(s.handle), uint64(col))
|
||||
return math.Float64frombits(r[0])
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// ColumnTime returns the value of the result column as a [time.Time].
|
||||
@@ -374,18 +375,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
|
||||
func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnText,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
ptr := uint32(r[0])
|
||||
if ptr == 0 {
|
||||
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
r = s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
return util.View(s.c.mod, ptr, r[0])
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
// ColumnRawBlob returns the value of the result column as a []byte.
|
||||
@@ -397,18 +387,19 @@ func (s *Stmt) ColumnRawText(col int) []byte {
|
||||
func (s *Stmt) ColumnRawBlob(col int) []byte {
|
||||
r := s.c.call(s.c.api.columnBlob,
|
||||
uint64(s.handle), uint64(col))
|
||||
return s.columnRawBytes(col, uint32(r))
|
||||
}
|
||||
|
||||
ptr := uint32(r[0])
|
||||
func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r[0])
|
||||
r := s.c.call(s.c.api.errcode, uint64(s.c.handle))
|
||||
s.err = s.c.error(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
r = s.c.call(s.c.api.columnBytes,
|
||||
r := s.c.call(s.c.api.columnBytes,
|
||||
uint64(s.handle), uint64(col))
|
||||
|
||||
return util.View(s.c.mod, ptr, r[0])
|
||||
return util.View(s.c.mod, ptr, r)
|
||||
}
|
||||
|
||||
// Return true if stmt is an empty SQL statement.
|
||||
|
||||
@@ -124,4 +124,46 @@ func TestBackup(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
func() { // Incremental.
|
||||
db, err := sqlite3.Open(backupName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
b, err := db.BackupInit("main", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer b.Close()
|
||||
|
||||
done, err := b.Step(1)
|
||||
if done {
|
||||
t.Error("want false")
|
||||
}
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
n := b.Remaining()
|
||||
if n != 1 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
n = b.PageCount()
|
||||
if n != 2 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
err = b.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ func (t params) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
func (sqliteDB) RunTest(t *testing.T, fn func(params)) {
|
||||
db, err := sql.Open("sqlite3", "file:"+
|
||||
filepath.Join(t.TempDir(), "foo.db")+
|
||||
"?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=synchronous(off)")
|
||||
"?_pragma=busy_timeout(10000)&_pragma=synchronous(off)")
|
||||
if err != nil {
|
||||
t.Fatalf("foo.db open fail: %v", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ package tests
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -22,6 +25,55 @@ func TestConn_Open_dir(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Open_notfound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := sqlite3.OpenFlags("test.db", sqlite3.OPEN_READONLY)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Open_modeof(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "test.db")
|
||||
mode := filepath.Join(dir, "modeof.txt")
|
||||
|
||||
fd, err := os.OpenFile(mode, os.O_CREATE, 0624)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fi, err := fd.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fd.Close()
|
||||
|
||||
db, err := sqlite3.Open("file:" + file + "?modeof=" + mode)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
di, err := os.Stat(file)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
if di.Mode() != fi.Mode() {
|
||||
t.Errorf("got %v, want %v", di.Mode(), fi.Mode())
|
||||
}
|
||||
|
||||
_, err = sqlite3.Open("file:" + file + "?modeof=" + mode + "2")
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Close(t *testing.T) {
|
||||
var conn *sqlite3.Conn
|
||||
conn.Close()
|
||||
@@ -58,6 +110,41 @@ func TestConn_Close_BUSY(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_Pragma(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open("file::memory:?_pragma=busy_timeout(1000)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
got, err := db.Pragma("busy_timeout")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []string{"1000"}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
|
||||
var serr *sqlite3.Error
|
||||
_, err = db.Pragma("+")
|
||||
if err == nil {
|
||||
t.Error("want: error")
|
||||
}
|
||||
if !errors.As(err, &serr) {
|
||||
t.Fatalf("got %T, want sqlite3.Error", err)
|
||||
}
|
||||
if rc := serr.Code(); rc != sqlite3.ERROR {
|
||||
t.Errorf("got %d, want sqlite3.ERROR", rc)
|
||||
}
|
||||
if got := err.Error(); got != `sqlite3: SQL logic error: near "+": syntax error` {
|
||||
t.Error("got message:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConn_SetInterrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -1,24 +1,45 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/wal.db
|
||||
var waldb []byte
|
||||
|
||||
func TestDB_memory(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, ":memory:")
|
||||
}
|
||||
|
||||
func TestDB_file(t *testing.T) {
|
||||
t.Parallel()
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
func TestDB_wal(t *testing.T) {
|
||||
t.Parallel()
|
||||
wal := filepath.Join(t.TempDir(), "test.db")
|
||||
err := os.WriteFile(wal, waldb, 0666)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testDB(t, wal)
|
||||
}
|
||||
|
||||
func TestDB_vfs(t *testing.T) {
|
||||
testDB(t, "file:test.db?vfs=memdb")
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
db, err := sqlite3.Open(name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -27,18 +27,32 @@ func TestDriver(t *testing.T) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.ExecContext(ctx,
|
||||
res, err := conn.ExecContext(ctx,
|
||||
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if changes != 0 {
|
||||
t.Errorf("got %d want 0", changes)
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != 0 {
|
||||
t.Errorf("got %d want 0", changes)
|
||||
}
|
||||
|
||||
res, err := conn.ExecContext(ctx,
|
||||
res, err = conn.ExecContext(ctx,
|
||||
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
changes, err := res.RowsAffected()
|
||||
changes, err = res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
188
tests/func_test.go
Normal file
188
tests/func_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestCreateFunction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg := arg[0]; arg.Int() {
|
||||
case 0:
|
||||
ctx.ResultInt(arg.Int())
|
||||
case 1:
|
||||
ctx.ResultInt64(arg.Int64())
|
||||
case 2:
|
||||
ctx.ResultBool(arg.Bool())
|
||||
case 3:
|
||||
ctx.ResultFloat(arg.Float())
|
||||
case 4:
|
||||
ctx.ResultText(arg.Text())
|
||||
case 5:
|
||||
ctx.ResultBlob(arg.Blob(nil))
|
||||
case 6:
|
||||
ctx.ResultZeroBlob(arg.Int64())
|
||||
case 7:
|
||||
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
|
||||
case 8:
|
||||
ctx.ResultNull()
|
||||
case 9:
|
||||
ctx.ResultError(sqlite3.FULL)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 1 {
|
||||
t.Errorf("got %v, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
|
||||
t.Errorf("got %v, want FLOAT", got)
|
||||
}
|
||||
if got := stmt.ColumnInt64(0); got != 3 {
|
||||
t.Errorf("got %v, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnText(0); got != "4" {
|
||||
t.Errorf("got %s, want 4", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); string(got) != "5" {
|
||||
t.Errorf("got %s, want 5", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
|
||||
t.Errorf("got %v, want BLOB", got)
|
||||
}
|
||||
if got := stmt.ColumnRawBlob(0); len(got) != 6 {
|
||||
t.Errorf("got %v, want 6", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
|
||||
t.Errorf("got %v, want TEXT", got)
|
||||
}
|
||||
if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.Step() {
|
||||
t.Error("want error")
|
||||
}
|
||||
if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
|
||||
t.Errorf("got %v, want sqlite3.FULL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnyCollationNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.AnyCollationNeeded()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := 0
|
||||
ids := []int{0, 2, 1}
|
||||
names := []string{"go", "whatever", "zig"}
|
||||
for ; stmt.Step(); row++ {
|
||||
id := stmt.ColumnInt(0)
|
||||
name := stmt.ColumnText(1)
|
||||
|
||||
if id != ids[row] {
|
||||
t.Errorf("got %d, want %d", id, ids[row])
|
||||
}
|
||||
if name != names[row] {
|
||||
t.Errorf("got %q, want %q", name, names[row])
|
||||
}
|
||||
}
|
||||
if row != 3 {
|
||||
t.Errorf("got %d, want %d", row, len(ids))
|
||||
}
|
||||
|
||||
if err := stmt.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
@@ -21,7 +22,27 @@ func TestParallel(t *testing.T) {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
name := "file:" +
|
||||
filepath.Join(t.TempDir(), "test.db") +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestMemory(t *testing.T) {
|
||||
var iter int
|
||||
if testing.Short() {
|
||||
iter = 1000
|
||||
} else {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
name := "file:/test.db?vfs=memdb" +
|
||||
"&_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(memory)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
@@ -31,8 +52,13 @@ func TestMultiProcess(t *testing.T) {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbname", name)
|
||||
file := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbfile", file)
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
|
||||
out, err := cmd.StdoutPipe()
|
||||
@@ -57,11 +83,16 @@ func TestMultiProcess(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChildProcess(t *testing.T) {
|
||||
name := os.Getenv("TestMultiProcess_dbname")
|
||||
if name == "" || testing.Short() {
|
||||
file := os.Getenv("TestMultiProcess_dbfile")
|
||||
if file == "" || testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
}
|
||||
|
||||
@@ -73,16 +104,6 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA busy_timeout=10000;
|
||||
PRAGMA synchronous=off;
|
||||
PRAGMA locking_mode=normal;
|
||||
PRAGMA journal_mode=truncate;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -103,10 +124,7 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA busy_timeout=10000;
|
||||
PRAGMA locking_mode=normal;
|
||||
`)
|
||||
err = db.Exec(`PRAGMA busy_timeout=10000`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
BIN
tests/testdata/wal.db
vendored
Normal file
BIN
tests/testdata/wal.db
vendored
Normal file
Binary file not shown.
36
tests/vfs_test.go
Normal file
36
tests/vfs_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
"github.com/ncruces/go-sqlite3/vfs/readervfs"
|
||||
)
|
||||
|
||||
func TestMemoryVFS_Open_notfound(t *testing.T) {
|
||||
memdb.Delete("demo.db")
|
||||
|
||||
_, err := sqlite3.Open("file:/demo.db?vfs=memdb&mode=ro")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReaderVFS_Open_notfound(t *testing.T) {
|
||||
readervfs.Delete("demo.db")
|
||||
|
||||
_, err := sqlite3.Open("file:demo.db?vfs=reader")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.CANTOPEN) {
|
||||
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
|
||||
}
|
||||
}
|
||||
125
value.go
Normal file
125
value.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Value is any value that can be stored in a database table.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value.html
|
||||
type Value struct {
|
||||
*sqlite
|
||||
handle uint32
|
||||
}
|
||||
|
||||
// Type returns the initial [Datatype] of the value.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Type() Datatype {
|
||||
r := v.call(v.api.valueType, uint64(v.handle))
|
||||
return Datatype(r)
|
||||
}
|
||||
|
||||
// Bool returns the value as a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are retrieved as integers,
|
||||
// with 0 converted to false and any other value to true.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Bool() bool {
|
||||
if i := v.Int64(); i != 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Int returns the value as an int.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int() int {
|
||||
return int(v.Int64())
|
||||
}
|
||||
|
||||
// Int64 returns the value as an int64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Int64() int64 {
|
||||
r := v.call(v.api.valueInteger, uint64(v.handle))
|
||||
return int64(r)
|
||||
}
|
||||
|
||||
// Float returns the value as a float64.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Float() float64 {
|
||||
r := v.call(v.api.valueFloat, uint64(v.handle))
|
||||
return math.Float64frombits(r)
|
||||
}
|
||||
|
||||
// Time returns the value as a [time.Time].
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Time(format TimeFormat) time.Time {
|
||||
var a any
|
||||
switch v.Type() {
|
||||
case INTEGER:
|
||||
a = v.Int64()
|
||||
case FLOAT:
|
||||
a = v.Float()
|
||||
case TEXT, BLOB:
|
||||
a = v.Text()
|
||||
case NULL:
|
||||
return time.Time{}
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
t, _ := format.Decode(a)
|
||||
return t
|
||||
}
|
||||
|
||||
// Text returns the value as a string.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Text() string {
|
||||
return string(v.RawText())
|
||||
}
|
||||
|
||||
// Blob appends to buf and returns
|
||||
// the value as a []byte.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Blob(buf []byte) []byte {
|
||||
return append(buf, v.RawBlob()...)
|
||||
}
|
||||
|
||||
// RawText returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawText() []byte {
|
||||
r := v.call(v.api.valueText, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
// RawBlob returns the value as a []byte.
|
||||
// The []byte is owned by SQLite and may be invalidated by
|
||||
// subsequent calls to [Value] methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) RawBlob() []byte {
|
||||
r := v.call(v.api.valueBlob, uint64(v.handle))
|
||||
return v.rawBytes(uint32(r))
|
||||
}
|
||||
|
||||
func (v Value) rawBytes(ptr uint32) []byte {
|
||||
if ptr == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
r := v.call(v.api.valueBytes, uint64(v.handle))
|
||||
return util.View(v.mod, ptr, r)
|
||||
}
|
||||
9
vfs/README.md
Normal file
9
vfs/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Go SQLite VFS API
|
||||
|
||||
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
|
||||
|
||||
It replaces the default VFS with a pure Go implementation,
|
||||
that is tested on Linux, macOS and Windows,
|
||||
but which should also work on illumos and the various BSDs.
|
||||
|
||||
It also exposes interfaces that should allow you to implement your own custom VFSes.
|
||||
103
vfs/api.go
Normal file
103
vfs/api.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// Package vfs wraps the C SQLite VFS API.
|
||||
package vfs
|
||||
|
||||
import "net/url"
|
||||
|
||||
// A VFS defines the interface between the SQLite core and the underlying operating system.
|
||||
//
|
||||
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/vfs.html
|
||||
type VFS interface {
|
||||
Open(name string, flags OpenFlag) (File, OpenFlag, error)
|
||||
Delete(name string, syncDir bool) error
|
||||
Access(name string, flags AccessFlag) (bool, error)
|
||||
FullPathname(name string) (string, error)
|
||||
}
|
||||
|
||||
// VFSParams extends VFS with the ability to handle URI parameters
|
||||
// through the OpenParams method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/uri_boolean.html
|
||||
type VFSParams interface {
|
||||
VFS
|
||||
OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error)
|
||||
}
|
||||
|
||||
// A File represents an open file in the OS interface layer.
|
||||
//
|
||||
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.
|
||||
// In particular, sqlite3.BUSY is necessary to correctly implement lock methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/io_methods.html
|
||||
type File interface {
|
||||
Close() error
|
||||
ReadAt(p []byte, off int64) (n int, err error)
|
||||
WriteAt(p []byte, off int64) (n int, err error)
|
||||
Truncate(size int64) error
|
||||
Sync(flags SyncFlag) error
|
||||
Size() (int64, error)
|
||||
Lock(lock LockLevel) error
|
||||
Unlock(lock LockLevel) error
|
||||
CheckReservedLock() (bool, error)
|
||||
SectorSize() int
|
||||
DeviceCharacteristics() DeviceCharacteristic
|
||||
}
|
||||
|
||||
// FileLockState extends File to implement the
|
||||
// SQLITE_FCNTL_LOCKSTATE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type FileLockState interface {
|
||||
File
|
||||
LockState() LockLevel
|
||||
}
|
||||
|
||||
// FileSizeHint extends File to implement the
|
||||
// SQLITE_FCNTL_SIZE_HINT file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type FileSizeHint interface {
|
||||
File
|
||||
SizeHint(size int64) error
|
||||
}
|
||||
|
||||
// FileHasMoved extends File to implement the
|
||||
// SQLITE_FCNTL_HAS_MOVED file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type FileHasMoved interface {
|
||||
File
|
||||
HasMoved() (bool, error)
|
||||
}
|
||||
|
||||
// FilePowersafeOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type FilePowersafeOverwrite interface {
|
||||
File
|
||||
PowersafeOverwrite() bool
|
||||
SetPowersafeOverwrite(bool)
|
||||
}
|
||||
|
||||
// FilePowersafeOverwrite extends File to implement the
|
||||
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type FileCommitPhaseTwo interface {
|
||||
File
|
||||
CommitPhaseTwo() error
|
||||
}
|
||||
|
||||
// FileBatchAtomicWrite extends File to implement the
|
||||
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
|
||||
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type FileBatchAtomicWrite interface {
|
||||
File
|
||||
BeginAtomicWrite() error
|
||||
CommitAtomicWrite() error
|
||||
RollbackAtomicWrite() error
|
||||
}
|
||||
215
vfs/const.go
Normal file
215
vfs/const.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package vfs
|
||||
|
||||
import "github.com/ncruces/go-sqlite3/internal/util"
|
||||
|
||||
const (
|
||||
_MAX_STRING = 512 // Used for short strings: names, error messages…
|
||||
_MAX_PATHNAME = 512
|
||||
_DEFAULT_SECTOR_SIZE = 4096
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/rescode.html
|
||||
type _ErrorCode uint32
|
||||
|
||||
func (e _ErrorCode) Error() string {
|
||||
return util.ErrorCodeString(uint32(e))
|
||||
}
|
||||
|
||||
const (
|
||||
_OK _ErrorCode = util.OK
|
||||
_PERM _ErrorCode = util.PERM
|
||||
_BUSY _ErrorCode = util.BUSY
|
||||
_READONLY _ErrorCode = util.READONLY
|
||||
_IOERR _ErrorCode = util.IOERR
|
||||
_NOTFOUND _ErrorCode = util.NOTFOUND
|
||||
_CANTOPEN _ErrorCode = util.CANTOPEN
|
||||
_IOERR_READ _ErrorCode = util.IOERR_READ
|
||||
_IOERR_SHORT_READ _ErrorCode = util.IOERR_SHORT_READ
|
||||
_IOERR_WRITE _ErrorCode = util.IOERR_WRITE
|
||||
_IOERR_FSYNC _ErrorCode = util.IOERR_FSYNC
|
||||
_IOERR_DIR_FSYNC _ErrorCode = util.IOERR_DIR_FSYNC
|
||||
_IOERR_TRUNCATE _ErrorCode = util.IOERR_TRUNCATE
|
||||
_IOERR_FSTAT _ErrorCode = util.IOERR_FSTAT
|
||||
_IOERR_UNLOCK _ErrorCode = util.IOERR_UNLOCK
|
||||
_IOERR_RDLOCK _ErrorCode = util.IOERR_RDLOCK
|
||||
_IOERR_DELETE _ErrorCode = util.IOERR_DELETE
|
||||
_IOERR_ACCESS _ErrorCode = util.IOERR_ACCESS
|
||||
_IOERR_CHECKRESERVEDLOCK _ErrorCode = util.IOERR_CHECKRESERVEDLOCK
|
||||
_IOERR_LOCK _ErrorCode = util.IOERR_LOCK
|
||||
_IOERR_CLOSE _ErrorCode = util.IOERR_CLOSE
|
||||
_IOERR_SEEK _ErrorCode = util.IOERR_SEEK
|
||||
_IOERR_DELETE_NOENT _ErrorCode = util.IOERR_DELETE_NOENT
|
||||
_IOERR_BEGIN_ATOMIC _ErrorCode = util.IOERR_BEGIN_ATOMIC
|
||||
_IOERR_COMMIT_ATOMIC _ErrorCode = util.IOERR_COMMIT_ATOMIC
|
||||
_IOERR_ROLLBACK_ATOMIC _ErrorCode = util.IOERR_ROLLBACK_ATOMIC
|
||||
_CANTOPEN_FULLPATH _ErrorCode = util.CANTOPEN_FULLPATH
|
||||
_CANTOPEN_ISDIR _ErrorCode = util.CANTOPEN_ISDIR
|
||||
_OK_SYMLINK _ErrorCode = util.OK_SYMLINK
|
||||
)
|
||||
|
||||
// OpenFlag is a flag for the [VFS] Open method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_open_autoproxy.html
|
||||
type OpenFlag uint32
|
||||
|
||||
const (
|
||||
OPEN_READONLY OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_READWRITE OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_CREATE OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_DELETEONCLOSE OpenFlag = 0x00000008 /* VFS only */
|
||||
OPEN_EXCLUSIVE OpenFlag = 0x00000010 /* VFS only */
|
||||
OPEN_AUTOPROXY OpenFlag = 0x00000020 /* VFS only */
|
||||
OPEN_URI OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MEMORY OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_MAIN_DB OpenFlag = 0x00000100 /* VFS only */
|
||||
OPEN_TEMP_DB OpenFlag = 0x00000200 /* VFS only */
|
||||
OPEN_TRANSIENT_DB OpenFlag = 0x00000400 /* VFS only */
|
||||
OPEN_MAIN_JOURNAL OpenFlag = 0x00000800 /* VFS only */
|
||||
OPEN_TEMP_JOURNAL OpenFlag = 0x00001000 /* VFS only */
|
||||
OPEN_SUBJOURNAL OpenFlag = 0x00002000 /* VFS only */
|
||||
OPEN_SUPER_JOURNAL OpenFlag = 0x00004000 /* VFS only */
|
||||
OPEN_NOMUTEX OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_FULLMUTEX OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_SHAREDCACHE OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_PRIVATECACHE OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */
|
||||
OPEN_WAL OpenFlag = 0x00080000 /* VFS only */
|
||||
OPEN_NOFOLLOW OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */
|
||||
)
|
||||
|
||||
// AccessFlag is a flag for the [VFS] Access method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_access_exists.html
|
||||
type AccessFlag uint32
|
||||
|
||||
const (
|
||||
ACCESS_EXISTS AccessFlag = 0
|
||||
ACCESS_READWRITE AccessFlag = 1 /* Used by PRAGMA temp_store_directory */
|
||||
ACCESS_READ AccessFlag = 2 /* Unused */
|
||||
)
|
||||
|
||||
// SyncFlag is a flag for the [File] Sync method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_sync_dataonly.html
|
||||
type SyncFlag uint32
|
||||
|
||||
const (
|
||||
SYNC_NORMAL SyncFlag = 0x00002
|
||||
SYNC_FULL SyncFlag = 0x00003
|
||||
SYNC_DATAONLY SyncFlag = 0x00010
|
||||
)
|
||||
|
||||
// LockLevel is a value used with [File] Lock and Unlock methods.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_lock_exclusive.html
|
||||
type LockLevel uint32
|
||||
|
||||
const (
|
||||
// No locks are held on the database.
|
||||
// The database may be neither read nor written.
|
||||
// Any internally cached data is considered suspect and subject to
|
||||
// verification against the database file before being used.
|
||||
// Other processes can read or write the database as their own locking
|
||||
// states permit.
|
||||
// This is the default state.
|
||||
LOCK_NONE LockLevel = 0 /* xUnlock() only */
|
||||
|
||||
// The database may be read but not written.
|
||||
// Any number of processes can hold SHARED locks at the same time,
|
||||
// hence there can be many simultaneous readers.
|
||||
// But no other thread or process is allowed to write to the database file
|
||||
// while one or more SHARED locks are active.
|
||||
LOCK_SHARED LockLevel = 1 /* xLock() or xUnlock() */
|
||||
|
||||
// A RESERVED lock means that the process is planning on writing to the
|
||||
// database file at some point in the future but that it is currently just
|
||||
// reading from the file.
|
||||
// Only a single RESERVED lock may be active at one time,
|
||||
// though multiple SHARED locks can coexist with a single RESERVED lock.
|
||||
// RESERVED differs from PENDING in that new SHARED locks can be acquired
|
||||
// while there is a RESERVED lock.
|
||||
LOCK_RESERVED LockLevel = 2 /* xLock() only */
|
||||
|
||||
// A PENDING lock means that the process holding the lock wants to write to
|
||||
// the database as soon as possible and is just waiting on all current
|
||||
// SHARED locks to clear so that it can get an EXCLUSIVE lock.
|
||||
// No new SHARED locks are permitted against the database if a PENDING lock
|
||||
// is active, though existing SHARED locks are allowed to continue.
|
||||
LOCK_PENDING LockLevel = 3 /* internal use only */
|
||||
|
||||
// An EXCLUSIVE lock is needed in order to write to the database file.
|
||||
// Only one EXCLUSIVE lock is allowed on the file and no other locks of any
|
||||
// kind are allowed to coexist with an EXCLUSIVE lock.
|
||||
// In order to maximize concurrency, SQLite works to minimize the amount of
|
||||
// time that EXCLUSIVE locks are held.
|
||||
LOCK_EXCLUSIVE LockLevel = 4 /* xLock() only */
|
||||
)
|
||||
|
||||
// DeviceCharacteristic is a flag retuned by the [File] DeviceCharacteristics method.
|
||||
//
|
||||
// https://www.sqlite.org/c3ref/c_iocap_atomic.html
|
||||
type DeviceCharacteristic uint32
|
||||
|
||||
const (
|
||||
IOCAP_ATOMIC DeviceCharacteristic = 0x00000001
|
||||
IOCAP_ATOMIC512 DeviceCharacteristic = 0x00000002
|
||||
IOCAP_ATOMIC1K DeviceCharacteristic = 0x00000004
|
||||
IOCAP_ATOMIC2K DeviceCharacteristic = 0x00000008
|
||||
IOCAP_ATOMIC4K DeviceCharacteristic = 0x00000010
|
||||
IOCAP_ATOMIC8K DeviceCharacteristic = 0x00000020
|
||||
IOCAP_ATOMIC16K DeviceCharacteristic = 0x00000040
|
||||
IOCAP_ATOMIC32K DeviceCharacteristic = 0x00000080
|
||||
IOCAP_ATOMIC64K DeviceCharacteristic = 0x00000100
|
||||
IOCAP_SAFE_APPEND DeviceCharacteristic = 0x00000200
|
||||
IOCAP_SEQUENTIAL DeviceCharacteristic = 0x00000400
|
||||
IOCAP_UNDELETABLE_WHEN_OPEN DeviceCharacteristic = 0x00000800
|
||||
IOCAP_POWERSAFE_OVERWRITE DeviceCharacteristic = 0x00001000
|
||||
IOCAP_IMMUTABLE DeviceCharacteristic = 0x00002000
|
||||
IOCAP_BATCH_ATOMIC DeviceCharacteristic = 0x00004000
|
||||
)
|
||||
|
||||
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
|
||||
type _FcntlOpcode uint32
|
||||
|
||||
const (
|
||||
_FCNTL_LOCKSTATE _FcntlOpcode = 1
|
||||
_FCNTL_GET_LOCKPROXYFILE _FcntlOpcode = 2
|
||||
_FCNTL_SET_LOCKPROXYFILE _FcntlOpcode = 3
|
||||
_FCNTL_LAST_ERRNO _FcntlOpcode = 4
|
||||
_FCNTL_SIZE_HINT _FcntlOpcode = 5
|
||||
_FCNTL_CHUNK_SIZE _FcntlOpcode = 6
|
||||
_FCNTL_FILE_POINTER _FcntlOpcode = 7
|
||||
_FCNTL_SYNC_OMITTED _FcntlOpcode = 8
|
||||
_FCNTL_WIN32_AV_RETRY _FcntlOpcode = 9
|
||||
_FCNTL_PERSIST_WAL _FcntlOpcode = 10
|
||||
_FCNTL_OVERWRITE _FcntlOpcode = 11
|
||||
_FCNTL_VFSNAME _FcntlOpcode = 12
|
||||
_FCNTL_POWERSAFE_OVERWRITE _FcntlOpcode = 13
|
||||
_FCNTL_PRAGMA _FcntlOpcode = 14
|
||||
_FCNTL_BUSYHANDLER _FcntlOpcode = 15
|
||||
_FCNTL_TEMPFILENAME _FcntlOpcode = 16
|
||||
_FCNTL_MMAP_SIZE _FcntlOpcode = 18
|
||||
_FCNTL_TRACE _FcntlOpcode = 19
|
||||
_FCNTL_HAS_MOVED _FcntlOpcode = 20
|
||||
_FCNTL_SYNC _FcntlOpcode = 21
|
||||
_FCNTL_COMMIT_PHASETWO _FcntlOpcode = 22
|
||||
_FCNTL_WIN32_SET_HANDLE _FcntlOpcode = 23
|
||||
_FCNTL_WAL_BLOCK _FcntlOpcode = 24
|
||||
_FCNTL_ZIPVFS _FcntlOpcode = 25
|
||||
_FCNTL_RBU _FcntlOpcode = 26
|
||||
_FCNTL_VFS_POINTER _FcntlOpcode = 27
|
||||
_FCNTL_JOURNAL_POINTER _FcntlOpcode = 28
|
||||
_FCNTL_WIN32_GET_HANDLE _FcntlOpcode = 29
|
||||
_FCNTL_PDB _FcntlOpcode = 30
|
||||
_FCNTL_BEGIN_ATOMIC_WRITE _FcntlOpcode = 31
|
||||
_FCNTL_COMMIT_ATOMIC_WRITE _FcntlOpcode = 32
|
||||
_FCNTL_ROLLBACK_ATOMIC_WRITE _FcntlOpcode = 33
|
||||
_FCNTL_LOCK_TIMEOUT _FcntlOpcode = 34
|
||||
_FCNTL_DATA_VERSION _FcntlOpcode = 35
|
||||
_FCNTL_SIZE_LIMIT _FcntlOpcode = 36
|
||||
_FCNTL_CKPT_DONE _FcntlOpcode = 37
|
||||
_FCNTL_RESERVE_BYTES _FcntlOpcode = 38
|
||||
_FCNTL_CKPT_START _FcntlOpcode = 39
|
||||
_FCNTL_EXTERNAL_READER _FcntlOpcode = 40
|
||||
_FCNTL_CKSM_FILE _FcntlOpcode = 41
|
||||
_FCNTL_RESET_CACHE _FcntlOpcode = 42
|
||||
)
|
||||
201
vfs/file.go
Normal file
201
vfs/file.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type vfsOS struct{}
|
||||
|
||||
func (vfsOS) FullPathname(path string) (string, error) {
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fi, err := os.Lstat(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return path, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if fi.Mode()&fs.ModeSymlink != 0 {
|
||||
err = _OK_SYMLINK
|
||||
}
|
||||
return path, err
|
||||
}
|
||||
|
||||
func (vfsOS) Delete(path string, syncDir bool) error {
|
||||
err := os.Remove(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return _IOERR_DELETE_NOENT
|
||||
}
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS != "windows" && syncDir {
|
||||
f, err := os.Open(filepath.Dir(path))
|
||||
if err != nil {
|
||||
return _OK
|
||||
}
|
||||
defer f.Close()
|
||||
err = osSync(f, false, false)
|
||||
if err != nil {
|
||||
return _IOERR_DIR_FSYNC
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vfsOS) Access(name string, flags AccessFlag) (bool, error) {
|
||||
err := osAccess(name, flags)
|
||||
if flags == ACCESS_EXISTS {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return false, nil
|
||||
}
|
||||
} else {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
|
||||
return vfsOS{}.OpenParams(name, flags, nil)
|
||||
}
|
||||
|
||||
func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error) {
|
||||
var oflags int
|
||||
if flags&OPEN_EXCLUSIVE != 0 {
|
||||
oflags |= os.O_EXCL
|
||||
}
|
||||
if flags&OPEN_CREATE != 0 {
|
||||
oflags |= os.O_CREATE
|
||||
}
|
||||
if flags&OPEN_READONLY != 0 {
|
||||
oflags |= os.O_RDONLY
|
||||
}
|
||||
if flags&OPEN_READWRITE != 0 {
|
||||
oflags |= os.O_RDWR
|
||||
}
|
||||
|
||||
var err error
|
||||
var f *os.File
|
||||
if name == "" {
|
||||
f, err = os.CreateTemp("", "*.db")
|
||||
} else {
|
||||
f, err = osOpenFile(name, oflags, 0666)
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EISDIR) {
|
||||
return nil, flags, _CANTOPEN_ISDIR
|
||||
}
|
||||
return nil, flags, err
|
||||
}
|
||||
|
||||
if modeof := params.Get("modeof"); modeof != "" {
|
||||
if err = osSetMode(f, modeof); err != nil {
|
||||
f.Close()
|
||||
return nil, flags, _IOERR_FSTAT
|
||||
}
|
||||
}
|
||||
if flags&OPEN_DELETEONCLOSE != 0 {
|
||||
os.Remove(f.Name())
|
||||
}
|
||||
|
||||
file := vfsFile{
|
||||
File: f,
|
||||
psow: true,
|
||||
readOnly: flags&OPEN_READONLY != 0,
|
||||
syncDir: runtime.GOOS != "windows" &&
|
||||
flags&(OPEN_CREATE) != 0 &&
|
||||
flags&(OPEN_MAIN_JOURNAL|OPEN_SUPER_JOURNAL|OPEN_WAL) != 0,
|
||||
}
|
||||
return &file, flags, nil
|
||||
}
|
||||
|
||||
type vfsFile struct {
|
||||
*os.File
|
||||
lockTimeout time.Duration
|
||||
lock LockLevel
|
||||
psow bool
|
||||
syncDir bool
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ FileLockState = &vfsFile{}
|
||||
_ FileHasMoved = &vfsFile{}
|
||||
_ FileSizeHint = &vfsFile{}
|
||||
_ FilePowersafeOverwrite = &vfsFile{}
|
||||
)
|
||||
|
||||
func (f *vfsFile) Sync(flags SyncFlag) error {
|
||||
dataonly := (flags & SYNC_DATAONLY) != 0
|
||||
fullsync := (flags & 0x0f) == SYNC_FULL
|
||||
|
||||
err := osSync(f.File, fullsync, dataonly)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS != "windows" && f.syncDir {
|
||||
f.syncDir = false
|
||||
d, err := os.Open(filepath.Dir(f.File.Name()))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer d.Close()
|
||||
err = osSync(d, false, false)
|
||||
if err != nil {
|
||||
return _IOERR_DIR_FSYNC
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *vfsFile) Size() (int64, error) {
|
||||
return f.Seek(0, io.SeekEnd)
|
||||
}
|
||||
|
||||
func (*vfsFile) SectorSize() int {
|
||||
return _DEFAULT_SECTOR_SIZE
|
||||
}
|
||||
|
||||
func (f *vfsFile) DeviceCharacteristics() DeviceCharacteristic {
|
||||
if f.psow {
|
||||
return IOCAP_POWERSAFE_OVERWRITE
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f *vfsFile) SizeHint(size int64) error {
|
||||
return osAllocate(f.File, size)
|
||||
}
|
||||
|
||||
func (f *vfsFile) HasMoved() (bool, error) {
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
pi, err := os.Stat(f.Name())
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return !os.SameFile(fi, pi), nil
|
||||
}
|
||||
|
||||
func (f *vfsFile) LockState() LockLevel { return f.lock }
|
||||
func (f *vfsFile) PowersafeOverwrite() bool { return f.psow }
|
||||
func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow }
|
||||
150
vfs/lock.go
Normal file
150
vfs/lock.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package vfs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
const (
|
||||
_PENDING_BYTE = 0x40000000
|
||||
_RESERVED_BYTE = (_PENDING_BYTE + 1)
|
||||
_SHARED_FIRST = (_PENDING_BYTE + 2)
|
||||
_SHARED_SIZE = 510
|
||||
)
|
||||
|
||||
func (f *vfsFile) Lock(lock LockLevel) error {
|
||||
// Argument check. SQLite never explicitly requests a pending lock.
|
||||
if lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
switch {
|
||||
case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE:
|
||||
// Connection state check.
|
||||
panic(util.AssertErr())
|
||||
case f.lock == LOCK_NONE && lock > LOCK_SHARED:
|
||||
// We never move from unlocked to anything higher than a shared lock.
|
||||
panic(util.AssertErr())
|
||||
case f.lock != LOCK_SHARED && lock == LOCK_RESERVED:
|
||||
// A shared lock is always held when a reserved lock is requested.
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
// If we already have an equal or more restrictive lock, do nothing.
|
||||
if f.lock >= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do not allow any kind of write-lock on a read-only database.
|
||||
if f.readOnly && lock >= LOCK_RESERVED {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
|
||||
switch lock {
|
||||
case LOCK_SHARED:
|
||||
// Must be unlocked to get SHARED.
|
||||
if f.lock != LOCK_NONE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_SHARED
|
||||
return nil
|
||||
|
||||
case LOCK_RESERVED:
|
||||
// Must be SHARED to get RESERVED.
|
||||
if f.lock != LOCK_SHARED {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_RESERVED
|
||||
return nil
|
||||
|
||||
case LOCK_EXCLUSIVE:
|
||||
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
|
||||
if f.lock <= LOCK_NONE || f.lock >= LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
|
||||
if f.lock < LOCK_PENDING {
|
||||
if rc := osGetPendingLock(f.File); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_PENDING
|
||||
}
|
||||
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_EXCLUSIVE
|
||||
return nil
|
||||
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (f *vfsFile) Unlock(lock LockLevel) error {
|
||||
// Argument check.
|
||||
if lock != LOCK_NONE && lock != LOCK_SHARED {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
// Connection state check.
|
||||
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
// If we don't have a more restrictive lock, do nothing.
|
||||
if f.lock <= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch lock {
|
||||
case LOCK_SHARED:
|
||||
if rc := osDowngradeLock(f.File, f.lock); rc != _OK {
|
||||
return rc
|
||||
}
|
||||
f.lock = LOCK_SHARED
|
||||
return nil
|
||||
|
||||
case LOCK_NONE:
|
||||
rc := osReleaseLock(f.File, f.lock)
|
||||
f.lock = LOCK_NONE
|
||||
return rc
|
||||
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
}
|
||||
|
||||
func (f *vfsFile) CheckReservedLock() (bool, error) {
|
||||
// Connection state check.
|
||||
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
if f.lock >= LOCK_RESERVED {
|
||||
return true, nil
|
||||
}
|
||||
return osCheckReservedLock(f.File)
|
||||
}
|
||||
|
||||
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
// Acquire the RESERVED lock.
|
||||
return osWriteLock(file, _RESERVED_BYTE, 1, timeout)
|
||||
}
|
||||
|
||||
func osGetPendingLock(file *os.File) _ErrorCode {
|
||||
// Acquire the PENDING lock.
|
||||
return osWriteLock(file, _PENDING_BYTE, 1, 0)
|
||||
}
|
||||
|
||||
func osCheckReservedLock(file *os.File) (bool, _ErrorCode) {
|
||||
// Test the RESERVED lock.
|
||||
return osCheckLock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/tetratelabs/wazero/experimental/wazerotest"
|
||||
)
|
||||
|
||||
func Test_vfsLock(t *testing.T) {
|
||||
@@ -39,12 +40,12 @@ func Test_vfsLock(t *testing.T) {
|
||||
pFile2 = 16
|
||||
pOutput = 32
|
||||
)
|
||||
mod := util.NewMockModule(128)
|
||||
ctx, vfs := Context(context.TODO())
|
||||
defer vfs.Close()
|
||||
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
|
||||
ctx, closer := util.NewContext(context.TODO())
|
||||
defer closer.Close()
|
||||
|
||||
openVFSFile(ctx, mod, pFile1, file1)
|
||||
openVFSFile(ctx, mod, pFile2, file2)
|
||||
vfsFileRegister(ctx, mod, pFile1, &vfsFile{File: file1})
|
||||
vfsFileRegister(ctx, mod, pFile2, &vfsFile{File: file2})
|
||||
|
||||
rc := vfsCheckReservedLock(ctx, mod, pFile1, pOutput)
|
||||
if rc != _OK {
|
||||
@@ -60,8 +61,15 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got != 0 {
|
||||
t.Error("file was locked")
|
||||
}
|
||||
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsLock(ctx, mod, pFile2, _LOCK_SHARED)
|
||||
rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
@@ -80,12 +88,19 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got != 0 {
|
||||
t.Error("file was locked")
|
||||
}
|
||||
|
||||
rc = vfsLock(ctx, mod, pFile2, _LOCK_RESERVED)
|
||||
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
rc = vfsLock(ctx, mod, pFile2, _LOCK_SHARED)
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsLock(ctx, mod, pFile2, LOCK_RESERVED)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
@@ -104,8 +119,15 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got == 0 {
|
||||
t.Error("file wasn't locked")
|
||||
}
|
||||
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_RESERVED) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsLock(ctx, mod, pFile2, _LOCK_EXCLUSIVE)
|
||||
rc = vfsLock(ctx, mod, pFile2, LOCK_EXCLUSIVE)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
@@ -124,8 +146,15 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got == 0 {
|
||||
t.Error("file wasn't locked")
|
||||
}
|
||||
rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_EXCLUSIVE) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsLock(ctx, mod, pFile1, _LOCK_SHARED)
|
||||
rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED)
|
||||
if rc == _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
@@ -144,8 +173,15 @@ func Test_vfsLock(t *testing.T) {
|
||||
if got := util.ReadUint32(mod, pOutput); got == 0 {
|
||||
t.Error("file wasn't locked")
|
||||
}
|
||||
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsUnlock(ctx, mod, pFile2, _LOCK_SHARED)
|
||||
rc = vfsUnlock(ctx, mod, pFile2, LOCK_SHARED)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
@@ -165,7 +201,19 @@ func Test_vfsLock(t *testing.T) {
|
||||
t.Error("file was locked")
|
||||
}
|
||||
|
||||
rc = vfsLock(ctx, mod, pFile1, _LOCK_SHARED)
|
||||
rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) {
|
||||
t.Error("invalid lock state", got)
|
||||
}
|
||||
|
||||
rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1)
|
||||
if rc != _OK {
|
||||
t.Fatal("returned", rc)
|
||||
}
|
||||
9
vfs/memdb/README.md
Normal file
9
vfs/memdb/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Go `"memdb"` SQLite VFS
|
||||
|
||||
This package implements the [`"memdb"`](https://www.sqlite.org/src/file/src/memdb.c)
|
||||
SQLite VFS in pure Go.
|
||||
|
||||
It has some benefits over the C version:
|
||||
- the memory backing the database needs not be contiguous,
|
||||
- the database can grow/shrink incrementally without copying,
|
||||
- reader-writer concurrency is slightly improved.
|
||||
59
vfs/memdb/api.go
Normal file
59
vfs/memdb/api.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package memdb implements the "memdb" SQLite VFS.
|
||||
//
|
||||
// The "memdb" [vfs.VFS] allows the same in-memory database to be shared
|
||||
// among multiple database connections in the same process,
|
||||
// as long as the database name begins with "/".
|
||||
//
|
||||
// Importing package memdb registers the VFS.
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
package memdb
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
)
|
||||
|
||||
func init() {
|
||||
vfs.Register("memdb", memVFS{})
|
||||
}
|
||||
|
||||
var (
|
||||
memoryMtx sync.Mutex
|
||||
// +checklocks:memoryMtx
|
||||
memoryDBs = map[string]*memDB{}
|
||||
)
|
||||
|
||||
// Create creates a shared memory database,
|
||||
// using data as its initial contents.
|
||||
// The new database takes ownership of data,
|
||||
// and the caller should not use data after this call.
|
||||
func Create(name string, data []byte) {
|
||||
memoryMtx.Lock()
|
||||
defer memoryMtx.Unlock()
|
||||
|
||||
db := new(memDB)
|
||||
db.size = int64(len(data))
|
||||
|
||||
sectors := divRoundUp(db.size, sectorSize)
|
||||
db.data = make([]*[sectorSize]byte, sectors)
|
||||
for i := range db.data {
|
||||
sector := data[i*sectorSize:]
|
||||
if len(sector) >= sectorSize {
|
||||
db.data[i] = (*[sectorSize]byte)(sector)
|
||||
} else {
|
||||
db.data[i] = new([sectorSize]byte)
|
||||
copy((*db.data[i])[:], sector)
|
||||
}
|
||||
}
|
||||
|
||||
memoryDBs[name] = db
|
||||
}
|
||||
|
||||
// Delete deletes a shared memory database.
|
||||
func Delete(name string) {
|
||||
memoryMtx.Lock()
|
||||
defer memoryMtx.Unlock()
|
||||
delete(memoryDBs, name)
|
||||
}
|
||||
51
vfs/memdb/example_test.go
Normal file
51
vfs/memdb/example_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package memdb_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "embed"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/test.db
|
||||
var testDB []byte
|
||||
|
||||
func Example() {
|
||||
memdb.Create("test.db", testDB)
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:/test.db?vfs=memdb")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s %s\n", id, name)
|
||||
}
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 2 whatever
|
||||
// 3 rust
|
||||
}
|
||||
294
vfs/memdb/memdb.go
Normal file
294
vfs/memdb/memdb.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package memdb
|
||||
|
||||
import (
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
)
|
||||
|
||||
type memVFS struct{}
|
||||
|
||||
func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
|
||||
// Allowed file types:
|
||||
// - databases, which only do page aligned reads/writes;
|
||||
// - temp journals, used by the sorter, which does the same.
|
||||
const types = vfs.OPEN_MAIN_DB |
|
||||
vfs.OPEN_TRANSIENT_DB |
|
||||
vfs.OPEN_TEMP_DB |
|
||||
vfs.OPEN_TEMP_JOURNAL
|
||||
if flags&types == 0 {
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
|
||||
var db *memDB
|
||||
|
||||
shared := strings.HasPrefix(name, "/")
|
||||
if shared {
|
||||
memoryMtx.Lock()
|
||||
defer memoryMtx.Unlock()
|
||||
db = memoryDBs[name[1:]]
|
||||
}
|
||||
if db == nil {
|
||||
if flags&vfs.OPEN_CREATE == 0 {
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
db = new(memDB)
|
||||
}
|
||||
if shared {
|
||||
memoryDBs[name[1:]] = db // +checklocksignore: lock is held
|
||||
}
|
||||
|
||||
return &memFile{
|
||||
memDB: db,
|
||||
readOnly: flags&vfs.OPEN_READONLY != 0,
|
||||
}, flags | vfs.OPEN_MEMORY, nil
|
||||
}
|
||||
|
||||
func (memVFS) Delete(name string, dirSync bool) error {
|
||||
return sqlite3.IOERR_DELETE
|
||||
}
|
||||
|
||||
func (memVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (memVFS) FullPathname(name string) (string, error) {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// Must be a multiple of 64K (the largest page size).
|
||||
const sectorSize = 65536
|
||||
|
||||
type memDB struct {
|
||||
// +checklocks:lockMtx
|
||||
pending *memFile
|
||||
// +checklocks:lockMtx
|
||||
reserved *memFile
|
||||
|
||||
// +checklocks:dataMtx
|
||||
data []*[sectorSize]byte
|
||||
|
||||
// +checklocks:dataMtx
|
||||
size int64
|
||||
|
||||
// +checklocks:lockMtx
|
||||
shared int
|
||||
|
||||
lockMtx sync.Mutex
|
||||
dataMtx sync.RWMutex
|
||||
}
|
||||
|
||||
type memFile struct {
|
||||
*memDB
|
||||
lock vfs.LockLevel
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ vfs.FileLockState = &memFile{}
|
||||
_ vfs.FileSizeHint = &memFile{}
|
||||
)
|
||||
|
||||
func (m *memFile) Close() error {
|
||||
return m.Unlock(vfs.LOCK_NONE)
|
||||
}
|
||||
|
||||
func (m *memFile) ReadAt(b []byte, off int64) (n int, err error) {
|
||||
m.dataMtx.RLock()
|
||||
defer m.dataMtx.RUnlock()
|
||||
|
||||
if off >= m.size {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
base := off / sectorSize
|
||||
rest := off % sectorSize
|
||||
have := int64(sectorSize)
|
||||
if base == int64(len(m.data))-1 {
|
||||
have = modRoundUp(m.size, sectorSize)
|
||||
}
|
||||
n = copy(b, (*m.data[base])[rest:have])
|
||||
if n < len(b) {
|
||||
// Assume reads are page aligned.
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
|
||||
base := off / sectorSize
|
||||
rest := off % sectorSize
|
||||
for base >= int64(len(m.data)) {
|
||||
m.data = append(m.data, new([sectorSize]byte))
|
||||
}
|
||||
n = copy((*m.data[base])[rest:], b)
|
||||
if n < len(b) {
|
||||
// Assume writes are page aligned.
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
if size := off + int64(len(b)); size > m.size {
|
||||
m.size = size
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *memFile) Truncate(size int64) error {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
return m.truncate(size)
|
||||
}
|
||||
|
||||
// +checklocks:m.dataMtx
|
||||
func (m *memFile) truncate(size int64) error {
|
||||
if size < m.size {
|
||||
base := size / sectorSize
|
||||
rest := size % sectorSize
|
||||
if rest != 0 {
|
||||
clear((*m.data[base])[rest:])
|
||||
}
|
||||
}
|
||||
sectors := divRoundUp(size, sectorSize)
|
||||
for sectors > int64(len(m.data)) {
|
||||
m.data = append(m.data, new([sectorSize]byte))
|
||||
}
|
||||
clear(m.data[sectors:])
|
||||
m.data = m.data[:sectors]
|
||||
m.size = size
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*memFile) Sync(flag vfs.SyncFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memFile) Size() (int64, error) {
|
||||
m.dataMtx.RLock()
|
||||
defer m.dataMtx.RUnlock()
|
||||
return m.size, nil
|
||||
}
|
||||
|
||||
func (m *memFile) Lock(lock vfs.LockLevel) error {
|
||||
if m.lock >= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.readOnly && lock >= vfs.LOCK_RESERVED {
|
||||
return sqlite3.IOERR_LOCK
|
||||
}
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
|
||||
switch lock {
|
||||
case vfs.LOCK_SHARED:
|
||||
if m.pending != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.shared++
|
||||
|
||||
case vfs.LOCK_RESERVED:
|
||||
if m.reserved != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.reserved = m
|
||||
|
||||
case vfs.LOCK_EXCLUSIVE:
|
||||
if m.lock < vfs.LOCK_PENDING {
|
||||
if m.pending != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lock = vfs.LOCK_PENDING
|
||||
m.pending = m
|
||||
}
|
||||
|
||||
for start := time.Now(); m.shared > 1; {
|
||||
if time.Since(start) > time.Millisecond {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
m.lock = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memFile) Unlock(lock vfs.LockLevel) error {
|
||||
if m.lock <= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
|
||||
if m.pending == m {
|
||||
m.pending = nil
|
||||
}
|
||||
if m.reserved == m {
|
||||
m.reserved = nil
|
||||
}
|
||||
if lock < vfs.LOCK_SHARED {
|
||||
m.shared--
|
||||
}
|
||||
m.lock = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memFile) CheckReservedLock() (bool, error) {
|
||||
if m.lock >= vfs.LOCK_RESERVED {
|
||||
return true, nil
|
||||
}
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
return m.reserved != nil, nil
|
||||
}
|
||||
|
||||
func (*memFile) SectorSize() int {
|
||||
return sectorSize
|
||||
}
|
||||
|
||||
func (*memFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
|
||||
return vfs.IOCAP_ATOMIC |
|
||||
vfs.IOCAP_SEQUENTIAL |
|
||||
vfs.IOCAP_SAFE_APPEND |
|
||||
vfs.IOCAP_POWERSAFE_OVERWRITE
|
||||
}
|
||||
|
||||
func (m *memFile) SizeHint(size int64) error {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
if size > m.size {
|
||||
return m.truncate(size)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memFile) LockState() vfs.LockLevel {
|
||||
return m.lock
|
||||
}
|
||||
|
||||
func divRoundUp(a, b int64) int64 {
|
||||
return (a + b - 1) / b
|
||||
}
|
||||
|
||||
func modRoundUp(a, b int64) int64 {
|
||||
return b - (b-a%b)%b
|
||||
}
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
BIN
vfs/memdb/testdata/test.db
vendored
Normal file
BIN
vfs/memdb/testdata/test.db
vendored
Normal file
Binary file not shown.
@@ -47,7 +47,7 @@ func osAllocate(file *os.File, size int64) error {
|
||||
Length: size,
|
||||
}
|
||||
|
||||
// Try to get a continous chunk of disk space.
|
||||
// Try to get a continuous chunk of disk space.
|
||||
err = unix.FcntlFstore(file.Fd(), unix.F_PREALLOCATE, &store)
|
||||
if err != nil {
|
||||
// OK, perhaps we are too fragmented, allocate non-continuous.
|
||||
@@ -5,6 +5,7 @@ package vfs
|
||||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -14,21 +15,33 @@ func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
|
||||
func osAccess(path string, flags _AccessFlag) error {
|
||||
func osAccess(path string, flags AccessFlag) error {
|
||||
var access uint32 // unix.F_OK
|
||||
switch flags {
|
||||
case _ACCESS_READWRITE:
|
||||
case ACCESS_READWRITE:
|
||||
access = unix.R_OK | unix.W_OK
|
||||
case _ACCESS_READ:
|
||||
case ACCESS_READ:
|
||||
access = unix.R_OK
|
||||
}
|
||||
return unix.Access(path, access)
|
||||
}
|
||||
|
||||
func osSetMode(file *os.File, modeof string) error {
|
||||
fi, err := os.Stat(modeof)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.Chmod(fi.Mode())
|
||||
if sys, ok := fi.Sys().(*syscall.Stat_t); ok {
|
||||
file.Chown(int(sys.Uid), int(sys.Gid))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
// Test the PENDING lock before acquiring a new SHARED lock.
|
||||
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {
|
||||
return _ErrorCode(_BUSY)
|
||||
return _BUSY
|
||||
}
|
||||
// Acquire the SHARED lock.
|
||||
return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
|
||||
@@ -43,8 +56,8 @@ func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout)
|
||||
}
|
||||
|
||||
func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
|
||||
if state >= _LOCK_EXCLUSIVE {
|
||||
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
if state >= LOCK_EXCLUSIVE {
|
||||
// Downgrade to a SHARED lock.
|
||||
if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK {
|
||||
// In theory, the downgrade to a SHARED cannot fail because another
|
||||
@@ -59,7 +72,7 @@ func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
|
||||
return osUnlock(file, _PENDING_BYTE, 2)
|
||||
}
|
||||
|
||||
func osReleaseLock(file *os.File, _ _LockLevel) _ErrorCode {
|
||||
func osReleaseLock(file *os.File, _ LockLevel) _ErrorCode {
|
||||
// Release all locks.
|
||||
return osUnlock(file, 0, 0)
|
||||
}
|
||||
@@ -78,9 +91,9 @@ func osLockErrorCode(err error, def _ErrorCode) _ErrorCode {
|
||||
unix.ENOLCK,
|
||||
unix.EDEADLK,
|
||||
unix.ETIMEDOUT:
|
||||
return _ErrorCode(_BUSY)
|
||||
return _BUSY
|
||||
case unix.EPERM:
|
||||
return _ErrorCode(_PERM)
|
||||
return _PERM
|
||||
}
|
||||
}
|
||||
return def
|
||||
@@ -25,17 +25,17 @@ func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
|
||||
return os.NewFile(uintptr(r), name), nil
|
||||
}
|
||||
|
||||
func osAccess(path string, flags _AccessFlag) error {
|
||||
func osAccess(path string, flags AccessFlag) error {
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if flags == _ACCESS_EXISTS {
|
||||
if flags == ACCESS_EXISTS {
|
||||
return nil
|
||||
}
|
||||
|
||||
var want fs.FileMode = windows.S_IRUSR
|
||||
if flags == _ACCESS_READWRITE {
|
||||
if flags == ACCESS_READWRITE {
|
||||
want |= windows.S_IWUSR
|
||||
}
|
||||
if fi.IsDir() {
|
||||
@@ -47,6 +47,15 @@ func osAccess(path string, flags _AccessFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func osSetMode(file *os.File, modeof string) error {
|
||||
fi, err := os.Stat(modeof)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
file.Chmod(fi.Mode())
|
||||
return nil
|
||||
}
|
||||
|
||||
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
|
||||
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)
|
||||
@@ -79,8 +88,8 @@ func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode {
|
||||
return rc
|
||||
}
|
||||
|
||||
func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
|
||||
if state >= _LOCK_EXCLUSIVE {
|
||||
func osDowngradeLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
if state >= LOCK_EXCLUSIVE {
|
||||
// Release the SHARED lock.
|
||||
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
|
||||
|
||||
@@ -93,24 +102,24 @@ func osDowngradeLock(file *os.File, state _LockLevel) _ErrorCode {
|
||||
}
|
||||
|
||||
// Release the PENDING and RESERVED locks.
|
||||
if state >= _LOCK_RESERVED {
|
||||
if state >= LOCK_RESERVED {
|
||||
osUnlock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
if state >= _LOCK_PENDING {
|
||||
if state >= LOCK_PENDING {
|
||||
osUnlock(file, _PENDING_BYTE, 1)
|
||||
}
|
||||
return _OK
|
||||
}
|
||||
|
||||
func osReleaseLock(file *os.File, state _LockLevel) _ErrorCode {
|
||||
func osReleaseLock(file *os.File, state LockLevel) _ErrorCode {
|
||||
// Release all locks.
|
||||
if state >= _LOCK_RESERVED {
|
||||
if state >= LOCK_RESERVED {
|
||||
osUnlock(file, _RESERVED_BYTE, 1)
|
||||
}
|
||||
if state >= _LOCK_SHARED {
|
||||
if state >= LOCK_SHARED {
|
||||
osUnlock(file, _SHARED_FIRST, _SHARED_SIZE)
|
||||
}
|
||||
if state >= _LOCK_PENDING {
|
||||
if state >= LOCK_PENDING {
|
||||
osUnlock(file, _PENDING_BYTE, 1)
|
||||
}
|
||||
return _OK
|
||||
5
vfs/readervfs/README.md
Normal file
5
vfs/readervfs/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Go `"reader"` SQLite VFS
|
||||
|
||||
This package implements a `"reader"` SQLite VFS
|
||||
that allows accessing any [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
|
||||
as an immutable SQLite database.
|
||||
84
vfs/readervfs/api.go
Normal file
84
vfs/readervfs/api.go
Normal file
@@ -0,0 +1,84 @@
|
||||
// Package readervfs implements an SQLite VFS for immutable databases.
|
||||
//
|
||||
// The "reader" [vfs.VFS] permits accessing any [io.ReaderAt]
|
||||
// as an immutable SQLite database.
|
||||
//
|
||||
// Importing package readervfs registers the VFS.
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/vfs/readervfs"
|
||||
package readervfs
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"sync"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
)
|
||||
|
||||
func init() {
|
||||
vfs.Register("reader", readerVFS{})
|
||||
}
|
||||
|
||||
var (
|
||||
readerMtx sync.RWMutex
|
||||
// +checklocks:readerMtx
|
||||
readerDBs = map[string]SizeReaderAt{}
|
||||
)
|
||||
|
||||
// Create creates an immutable database from reader.
|
||||
// The caller should ensure that data from reader does not mutate,
|
||||
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
|
||||
func Create(name string, reader SizeReaderAt) {
|
||||
readerMtx.Lock()
|
||||
defer readerMtx.Unlock()
|
||||
readerDBs[name] = reader
|
||||
}
|
||||
|
||||
// Delete deletes a shared memory database.
|
||||
func Delete(name string) {
|
||||
readerMtx.Lock()
|
||||
defer readerMtx.Unlock()
|
||||
delete(readerDBs, name)
|
||||
}
|
||||
|
||||
// A SizeReaderAt is a ReaderAt with a Size method.
|
||||
// Use [NewSizeReaderAt] to adapt different Size interfaces.
|
||||
type SizeReaderAt interface {
|
||||
Size() (int64, error)
|
||||
io.ReaderAt
|
||||
}
|
||||
|
||||
// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt
|
||||
// that implements one of:
|
||||
// - Size() (int64, error)
|
||||
// - Size() int64
|
||||
// - Len() int
|
||||
// - Stat() (fs.FileInfo, error)
|
||||
// - Seek(offset int64, whence int) (int64, error)
|
||||
func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt {
|
||||
return sizer{r}
|
||||
}
|
||||
|
||||
type sizer struct{ io.ReaderAt }
|
||||
|
||||
func (s sizer) Size() (int64, error) {
|
||||
switch s := s.ReaderAt.(type) {
|
||||
case interface{ Size() (int64, error) }:
|
||||
return s.Size()
|
||||
case interface{ Size() int64 }:
|
||||
return s.Size(), nil
|
||||
case interface{ Len() int }:
|
||||
return int64(s.Len()), nil
|
||||
case interface{ Stat() (fs.FileInfo, error) }:
|
||||
fi, err := s.Stat()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return fi.Size(), nil
|
||||
case io.Seeker:
|
||||
return s.Seek(0, io.SeekEnd)
|
||||
}
|
||||
return 0, sqlite3.IOERR_SEEK
|
||||
}
|
||||
95
vfs/readervfs/example_test.go
Normal file
95
vfs/readervfs/example_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package readervfs_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "embed"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/vfs/readervfs"
|
||||
"github.com/psanford/httpreadat"
|
||||
)
|
||||
|
||||
//go:embed testdata/test.db
|
||||
var testDB []byte
|
||||
|
||||
func Example_http() {
|
||||
readervfs.Create("demo.db", httpreadat.New("https://www.sanford.io/demo.db"))
|
||||
defer readervfs.Delete("demo.db")
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:demo.db?vfs=reader")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
magname := map[int]string{
|
||||
3: "thousand",
|
||||
6: "million",
|
||||
9: "billion",
|
||||
}
|
||||
rows, err := db.Query(`
|
||||
SELECT period, data_value, magntude, units FROM csv
|
||||
WHERE period > '2010'
|
||||
LIMIT 10`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var period, units string
|
||||
var value int64
|
||||
var mag int
|
||||
err = rows.Scan(&period, &value, &mag, &units)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s: %d %s %s\n", period, value, magname[mag], units)
|
||||
}
|
||||
// Output:
|
||||
// 2010.03: 17463 million Dollars
|
||||
// 2010.06: 17260 million Dollars
|
||||
// 2010.09: 15419 million Dollars
|
||||
// 2010.12: 17088 million Dollars
|
||||
// 2011.03: 18516 million Dollars
|
||||
// 2011.06: 18835 million Dollars
|
||||
// 2011.09: 16390 million Dollars
|
||||
// 2011.12: 18748 million Dollars
|
||||
// 2012.03: 18477 million Dollars
|
||||
// 2012.06: 18270 million Dollars
|
||||
}
|
||||
|
||||
func Example_embed() {
|
||||
readervfs.Create("test.db", readervfs.NewSizeReaderAt(bytes.NewReader(testDB)))
|
||||
defer readervfs.Delete("test.db")
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:test.db?vfs=reader")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s %s\n", id, name)
|
||||
}
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 2 whatever
|
||||
}
|
||||
74
vfs/readervfs/reader.go
Normal file
74
vfs/readervfs/reader.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package readervfs
|
||||
|
||||
import (
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/vfs"
|
||||
)
|
||||
|
||||
type readerVFS struct{}
|
||||
|
||||
// Open implements the [vfs.VFS] interface.
|
||||
func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
|
||||
if flags&vfs.OPEN_MAIN_DB == 0 {
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
readerMtx.RLock()
|
||||
defer readerMtx.RUnlock()
|
||||
if ra, ok := readerDBs[name]; ok {
|
||||
return readerFile{ra}, flags | vfs.OPEN_READONLY, nil
|
||||
}
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
|
||||
// Delete implements the [vfs.VFS] interface.
|
||||
func (readerVFS) Delete(name string, dirSync bool) error {
|
||||
return sqlite3.IOERR_DELETE
|
||||
}
|
||||
|
||||
// Access implements the [vfs.VFS] interface.
|
||||
func (readerVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// FullPathname implements the [vfs.VFS] interface.
|
||||
func (readerVFS) FullPathname(name string) (string, error) {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
type readerFile struct{ SizeReaderAt }
|
||||
|
||||
func (readerFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (readerFile) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
return 0, sqlite3.READONLY
|
||||
}
|
||||
|
||||
func (readerFile) Truncate(size int64) error {
|
||||
return sqlite3.READONLY
|
||||
}
|
||||
|
||||
func (readerFile) Sync(flag vfs.SyncFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (readerFile) Lock(lock vfs.LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (readerFile) Unlock(lock vfs.LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (readerFile) CheckReservedLock() (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (readerFile) SectorSize() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (readerFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
|
||||
return vfs.IOCAP_IMMUTABLE
|
||||
}
|
||||
87
vfs/readervfs/reader_test.go
Normal file
87
vfs/readervfs/reader_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package readervfs
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewSizeReaderAt(t *testing.T) {
|
||||
f, err := os.Create(filepath.Join(t.TempDir(), "abc.txt"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
n, err := NewSizeReaderAt(f).Size()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
reader := strings.NewReader("abc")
|
||||
|
||||
n, err = NewSizeReaderAt(reader).Size()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
n, err = NewSizeReaderAt(readlener{reader, reader.Len()}).Size()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
n, err = NewSizeReaderAt(readsizer{reader, reader.Size()}).Size()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
n, err = NewSizeReaderAt(readseeker{reader, reader}).Size()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("got %d", n)
|
||||
}
|
||||
|
||||
_, err = NewSizeReaderAt(readerat{reader}).Size()
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
}
|
||||
|
||||
type readlener struct {
|
||||
io.ReaderAt
|
||||
len int
|
||||
}
|
||||
|
||||
func (l readlener) Len() int { return l.len }
|
||||
|
||||
type readsizer struct {
|
||||
io.ReaderAt
|
||||
size int64
|
||||
}
|
||||
|
||||
func (l readsizer) Size() (int64, error) { return l.size, nil }
|
||||
|
||||
type readseeker struct {
|
||||
io.ReaderAt
|
||||
io.Seeker
|
||||
}
|
||||
|
||||
type readerat struct {
|
||||
io.ReaderAt
|
||||
}
|
||||
BIN
vfs/readervfs/testdata/test.db
vendored
Normal file
BIN
vfs/readervfs/testdata/test.db
vendored
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user