Compare commits

...

41 Commits

Author SHA1 Message Date
Nuno Cruces
dcc845d684 wazero v1.2.1. 2023-06-15 03:43:25 +01:00
dependabot[bot]
f1b42c26d5 Bump golang.org/x/sync from 0.2.0 to 0.3.0
Bumps [golang.org/x/sync](https://github.com/golang/sync) from 0.2.0 to 0.3.0.
- [Commits](https://github.com/golang/sync/compare/v0.2.0...v0.3.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sync
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-15 00:13:43 +01:00
dependabot[bot]
1e94407ae7 Bump golang.org/x/sys from 0.8.0 to 0.9.0
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.8.0 to 0.9.0.
- [Commits](https://github.com/golang/sys/compare/v0.8.0...v0.9.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-13 00:31:02 +01:00
Nuno Cruces
eb8d9b95fd Consistent lock timeouts. 2023-06-12 13:04:37 +01:00
Nuno Cruces
04037a75ed GORM driver sync. 2023-06-12 10:56:03 +01:00
Nuno Cruces
2472ceb0a0 Fix GORM module name. 2023-06-07 12:40:18 +01:00
Nuno Cruces
bfe9bfde2e Make GORM driver its own module. 2023-06-07 12:00:46 +01:00
Nuno Cruces
f07e82e361 GORM driver. 2023-06-06 12:37:54 +01:00
Nuno Cruces
fbbbe5a631 Fix plain files. 2023-06-06 03:47:02 +01:00
Nuno Cruces
5ea603ed78 Readers should not close. 2023-06-02 15:00:12 +01:00
Nuno Cruces
401cb77e38 binaryen-version_113. 2023-06-02 14:23:12 +01:00
Nuno Cruces
6511175011 wazero v1.2.0. 2023-06-02 14:11:20 +01:00
Nuno Cruces
f7d987fdf1 Commit phase-two API. 2023-06-02 13:40:08 +01:00
Nuno Cruces
00ba681bb5 Batch atomic writes API. 2023-06-02 11:14:34 +01:00
Nuno Cruces
d4d4533a41 Docs. 2023-06-02 03:38:26 +01:00
Nuno Cruces
ec9533b13f Implement modeof. 2023-06-02 03:38:26 +01:00
Nuno Cruces
8fe77a065c Remove wzprof. 2023-06-02 03:38:26 +01:00
Nuno Cruces
7bf5312bd4 Rename. 2023-06-02 03:38:26 +01:00
Nuno Cruces
ae7b74d858 Upgrade wzprof. 2023-06-01 16:09:18 +01:00
Nuno Cruces
9a8de3ad13 Enable memdb on speedtest1. 2023-06-01 15:41:20 +01:00
Nuno Cruces
05737e6025 Refactor reader VFS API. 2023-05-31 19:24:41 +01:00
Nuno Cruces
ac2836bb82 Refactor memdb API. 2023-05-31 16:27:31 +01:00
Nuno Cruces
d0d4b0e1a2 MemoryVFS mutexes. 2023-05-31 12:57:18 +01:00
Nuno Cruces
dc3dc6853d MemoryVFS journal. 2023-05-31 11:56:48 +01:00
dependabot[bot]
830240c368 Bump github.com/stealthrocket/wzprof from 0.1.3 to 0.1.4
Bumps [github.com/stealthrocket/wzprof](https://github.com/stealthrocket/wzprof) from 0.1.3 to 0.1.4.
- [Release notes](https://github.com/stealthrocket/wzprof/releases)
- [Changelog](https://github.com/stealthrocket/wzprof/blob/main/.goreleaser.yml)
- [Commits](https://github.com/stealthrocket/wzprof/compare/v0.1.3...v0.1.4)

---
updated-dependencies:
- dependency-name: github.com/stealthrocket/wzprof
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-05-30 13:40:52 +01:00
Nuno Cruces
dedec8682b Driver improvements. 2023-05-30 13:39:34 +01:00
Nuno Cruces
a33b828e13 Examples, tests, max size. 2023-05-30 11:21:14 +01:00
Nuno Cruces
8b2e96dedc Tests, fixes, docs. 2023-05-29 16:52:43 +01:00
Nuno Cruces
f1c46db512 VFS locking. 2023-05-27 23:36:39 +01:00
Nuno Cruces
7ca9d79424 MemoryVFS. 2023-05-27 23:36:39 +01:00
Nuno Cruces
254d473546 VFS URI parameters. 2023-05-27 23:36:39 +01:00
Nuno Cruces
5639fc1ff8 Update wzprof. 2023-05-27 23:36:39 +01:00
Nuno Cruces
ae4954d09b Profile with wzprof. 2023-05-27 23:36:39 +01:00
Nuno Cruces
45937d9749 Use wazerotest. 2023-05-27 23:36:39 +01:00
Nuno Cruces
eee71e06aa Tweak calling convention. 2023-05-25 17:03:40 +01:00
Nuno Cruces
9e7b6bb8ea Improve connection setup. 2023-05-25 11:14:18 +01:00
Nuno Cruces
597178f80d Backup fix, tests. 2023-05-24 02:47:18 +01:00
Nuno Cruces
cc2d16ac83 ReaderVFS. 2023-05-23 16:34:09 +01:00
Nuno Cruces
cfb69e4ce7 Reorg. 2023-05-23 14:47:39 +01:00
Nuno Cruces
e6969432e3 Rename. 2023-05-23 14:47:38 +01:00
Nuno Cruces
2b3da350cc Improved error handling. 2023-05-23 14:47:38 +01:00
94 changed files with 3214 additions and 861 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
custom: https://www.paypal.com/donate/buttons/manage/33P59ELZWGMK6
custom: https://www.paypal.com/donate?hosted_button_id=33P59ELZWGMK6

View File

@@ -34,8 +34,9 @@ jobs:
- name: Download
run: go mod download
- name: Verify
run: go mod verify
# Fixed in go 1.21: https://go.dev/issue/54372
# - name: Verify
# run: go mod verify
- name: Vet
run: go vet ./...

View File

@@ -15,14 +15,16 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
([example usage](https://pkg.go.dev/github.com/ncruces/go-sqlite3/driver#example-package)).
- Package [`github.com/ncruces/go-sqlite3/embed`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/embed)
embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/sqlite3vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/sqlite3vfs)
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Caveats
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS)
with a [pure Go](sqlite3vfs/) implementation.
This has numerous benefits, but also comes with some drawbacks.
This module replaces the SQLite [OS Interface](https://www.sqlite.org/vfs.html)
(aka VFS) with a [pure Go](vfs/) implementation.
This has benefits, but also comes with some drawbacks.
#### Write-Ahead Logging
@@ -57,7 +59,7 @@ BSD locks may _not_ be compatible with process-associated POSIX advisory locks.
#### Testing
The pure Go VFS is stress tested by running an unmodified build of SQLite's
The pure Go VFS is tested by running an unmodified build of SQLite's
[mptest](https://github.com/sqlite/sqlite/blob/master/mptest/mptest.c)
on Linux, macOS and Windows.
Performance is tested by running
@@ -70,12 +72,12 @@ Performance is tested by running
- [x] incremental BLOB I/O
- [x] online backup
- [ ] session extension
- [ ] custom SQL functions
- [ ] custom VFSes
- [ ] in-memory VFS
- [ ] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [x] custom VFS API
- [x] in-memory VFS
- [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
- [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki)
- [ ] custom SQL functions
### Alternatives

View File

@@ -74,16 +74,16 @@ func (c *Conn) backupInit(dst uint32, dstName string, src uint32, srcName string
r := c.call(c.api.backupInit,
uint64(dst), uint64(dstPtr),
uint64(src), uint64(srcPtr))
if r[0] == 0 {
if r == 0 {
defer c.closeDB(other)
r = c.call(c.api.errcode, uint64(dst))
return nil, c.module.error(r[0], dst)
return nil, c.module.error(r, dst)
}
return &Backup{
c: c,
otherc: other,
handle: uint32(r[0]),
handle: uint32(r),
}, nil
}
@@ -100,7 +100,7 @@ func (b *Backup) Close() error {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
b.c.closeDB(b.otherc)
b.handle = 0
return b.c.error(r[0])
return b.c.error(r)
}
// Step copies up to nPage pages between the source and destination databases.
@@ -109,10 +109,10 @@ func (b *Backup) Close() error {
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep
func (b *Backup) Step(nPage int) (done bool, err error) {
r := b.c.call(b.c.api.backupStep, uint64(b.handle), uint64(nPage))
if r[0] == _DONE {
if r == _DONE {
return true, nil
}
return false, b.c.error(r[0])
return false, b.c.error(r)
}
// Remaining returns the number of pages still to be backed up
@@ -121,7 +121,7 @@ func (b *Backup) Step(nPage int) (done bool, err error) {
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupremaining
func (b *Backup) Remaining() int {
r := b.c.call(b.c.api.backupRemaining, uint64(b.handle))
return int(r[0])
return int(r)
}
// PageCount returns the total number of pages in the source database
@@ -129,6 +129,6 @@ func (b *Backup) Remaining() int {
//
// https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backuppagecount
func (b *Backup) PageCount() int {
r := b.c.call(b.c.api.backupFinish, uint64(b.handle))
return int(r[0])
r := b.c.call(b.c.api.backupPageCount, uint64(b.handle))
return int(r)
}

18
blob.go
View File

@@ -45,13 +45,13 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr))
if err := c.error(r[0]); err != nil {
if err := c.error(r); err != nil {
return nil, err
}
blob := Blob{c: c}
blob.handle = util.ReadUint32(c.mod, blobPtr)
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0])
blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle)))
return &blob, nil
}
@@ -68,7 +68,7 @@ func (b *Blob) Close() error {
r := b.c.call(b.c.api.blobClose, uint64(b.handle))
b.handle = 0
return b.c.error(r[0])
return b.c.error(r)
}
// Size returns the size of the BLOB in bytes.
@@ -97,7 +97,7 @@ func (b *Blob) Read(p []byte) (n int, err error) {
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
err = b.c.error(r)
if err != nil {
return 0, err
}
@@ -130,7 +130,7 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) {
for want > 0 {
r := b.c.call(b.c.api.blobRead, uint64(b.handle),
uint64(ptr), uint64(want), uint64(b.offset))
err = b.c.error(r[0])
err = b.c.error(r)
if err != nil {
return n, err
}
@@ -163,7 +163,7 @@ func (b *Blob) Write(p []byte) (n int, err error) {
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(len(p)), uint64(b.offset))
err = b.c.error(r[0])
err = b.c.error(r)
if err != nil {
return 0, err
}
@@ -193,7 +193,7 @@ func (b *Blob) ReadFrom(r io.Reader) (n int64, err error) {
if m > 0 {
r := b.c.call(b.c.api.blobWrite, uint64(b.handle),
uint64(ptr), uint64(m), uint64(b.offset))
err := b.c.error(r[0])
err := b.c.error(r)
if err != nil {
return n, err
}
@@ -240,8 +240,8 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) {
//
// https://www.sqlite.org/c3ref/blob_reopen.html
func (b *Blob) Reopen(row int64) error {
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row))[0])
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle))[0])
err := b.c.error(b.c.call(b.c.api.blobReopen, uint64(b.handle), uint64(row)))
b.bytes = int64(b.c.call(b.c.api.blobBytes, uint64(b.handle)))
b.offset = 0
return err
}

30
conn.go
View File

@@ -80,7 +80,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0)
handle := util.ReadUint32(c.mod, connPtr)
if err := c.module.error(r[0], handle); err != nil {
if err := c.module.error(r, handle); err != nil {
c.closeDB(handle)
return 0, err
}
@@ -99,7 +99,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
c.arena.reset()
pragmaPtr := c.arena.string(pragmas.String())
r := c.call(c.api.exec, uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.module.error(r[0], handle, pragmas.String()); err != nil {
if err := c.module.error(r, handle, pragmas.String()); err != nil {
if errors.Is(err, ERROR) {
err = fmt.Errorf("sqlite3: invalid _pragma: %w", err)
}
@@ -113,7 +113,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
func (c *Conn) closeDB(handle uint32) {
r := c.call(c.api.closeZombie, uint64(handle))
if err := c.module.error(r[0], handle); err != nil {
if err := c.module.error(r, handle); err != nil {
panic(err)
}
}
@@ -137,7 +137,7 @@ func (c *Conn) Close() error {
c.pending = nil
r := c.call(c.api.close, uint64(c.handle))
if err := c.error(r[0]); err != nil {
if err := c.error(r); err != nil {
return err
}
@@ -156,7 +156,7 @@ func (c *Conn) Exec(sql string) error {
sqlPtr := c.arena.string(sql)
r := c.call(c.api.exec, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r[0])
return c.error(r)
}
// Prepare calls [Conn.PrepareFlags] with no flags.
@@ -189,7 +189,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
i := util.ReadUint32(c.mod, tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
if err := c.error(r, sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
@@ -203,7 +203,7 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str
// https://www.sqlite.org/c3ref/get_autocommit.html
func (c *Conn) GetAutocommit() bool {
r := c.call(c.api.autocommit, uint64(c.handle))
return r[0] != 0
return r != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
@@ -212,7 +212,7 @@ func (c *Conn) GetAutocommit() bool {
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() int64 {
r := c.call(c.api.lastRowid, uint64(c.handle))
return int64(r[0])
return int64(r)
}
// Changes returns the number of rows modified, inserted or deleted
@@ -222,7 +222,7 @@ func (c *Conn) LastInsertRowID() int64 {
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() int64 {
r := c.call(c.api.changes, uint64(c.handle))
return int64(r[0])
return int64(r)
}
// SetInterrupt interrupts a long-running query when a context is done.
@@ -325,17 +325,21 @@ func (c *Conn) error(rc uint64, sql ...string) error {
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// It can be used to access advanced SQLite features like
// [savepoints] and [incremental BLOB I/O].
// [savepoints], [online backup] and [incremental BLOB I/O].
//
// [savepoints]: https://www.sqlite.org/lang_savepoint.html
// [online backup]: https://www.sqlite.org/backup.html
// [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html
type DriverConn interface {
driver.Conn
driver.ConnBeginTx
driver.ExecerContext
driver.ConnPrepareContext
Savepoint() Savepoint
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
SetInterrupt(ctx context.Context) (old context.Context)
Savepoint() Savepoint
Backup(srcDB, dstURI string) error
Restore(dstDB, srcURI string) error
OpenBlob(db, table, column string, row int64, write bool) (*Blob, error)
}

View File

@@ -19,6 +19,9 @@
// If no PRAGMAs are specified, a busy timeout of 1 minute
// and normal locking mode are used.
//
// Order matters:
// busy timeout and locking mode should be the first PRAGMAs set, in that order.
//
// [URI]: https://www.sqlite.org/uri.html
// [PRAGMA]: https://www.sqlite.org/pragma.html
// [TRANSACTION]: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
@@ -46,11 +49,12 @@ type sqlite struct{}
func (sqlite) Open(name string) (_ driver.Conn, err error) {
var c conn
c.conn, err = sqlite3.Open(name)
c.Conn, err = sqlite3.Open(name)
if err != nil {
return nil, err
}
c.txBegin = "BEGIN"
var pragmas []string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
@@ -70,9 +74,9 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
}
if len(pragmas) == 0 {
err := c.conn.Exec(`
PRAGMA locking_mode=normal;
err := c.Conn.Exec(`
PRAGMA busy_timeout=60000;
PRAGMA locking_mode=normal;
`)
if err != nil {
c.Close()
@@ -80,7 +84,7 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
}
c.reusable = true
} else {
s, _, err := c.conn.Prepare(`
s, _, err := c.Conn.Prepare(`
SELECT * FROM
PRAGMA_locking_mode,
PRAGMA_query_only;
@@ -99,11 +103,11 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
return nil, err
}
}
return c, nil
return &c, nil
}
type conn struct {
conn *sqlite3.Conn
*sqlite3.Conn
txBegin string
txCommit string
txRollback string
@@ -113,25 +117,21 @@ type conn struct {
var (
// Ensure these interfaces are implemented:
_ driver.ExecerContext = conn{}
_ driver.ConnBeginTx = conn{}
_ driver.Validator = conn{}
_ sqlite3.DriverConn = conn{}
_ driver.ExecerContext = &conn{}
_ driver.ConnBeginTx = &conn{}
_ driver.Validator = &conn{}
_ sqlite3.DriverConn = &conn{}
)
func (c conn) Close() error {
return c.conn.Close()
}
func (c conn) IsValid() bool {
func (c *conn) IsValid() bool {
return c.reusable
}
func (c conn) Begin() (driver.Tx, error) {
func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
txBegin := c.txBegin
c.txCommit = `COMMIT`
c.txRollback = `ROLLBACK`
@@ -155,33 +155,43 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro
break
}
err := c.conn.Exec(txBegin)
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.Conn.Exec(txBegin)
if err != nil {
return nil, err
}
return c, nil
}
func (c conn) Commit() error {
err := c.conn.Exec(c.txCommit)
if err != nil && !c.conn.GetAutocommit() {
func (c *conn) Commit() error {
err := c.Conn.Exec(c.txCommit)
if err != nil && !c.GetAutocommit() {
c.Rollback()
}
return err
}
func (c conn) Rollback() error {
return c.conn.Exec(c.txRollback)
func (c *conn) Rollback() error {
return c.Conn.Exec(c.txRollback)
}
func (c conn) Prepare(query string) (driver.Stmt, error) {
s, tail, err := c.conn.Prepare(query)
func (c *conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
s, tail, err := c.Conn.Prepare(query)
if err != nil {
return nil, err
}
if tail != "" {
// Check if the tail contains any SQL.
st, _, err := c.conn.Prepare(tail)
st, _, err := c.Conn.Prepare(tail)
if err != nil {
s.Close()
return nil, err
@@ -192,62 +202,46 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
return nil, util.TailErr
}
}
return stmt{s, c.conn}, nil
return &stmt{s, c.Conn}, nil
}
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
return c.Prepare(query)
}
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if len(args) != 0 {
// Slow path.
return nil, driver.ErrSkip
}
old := c.conn.SetInterrupt(ctx)
defer c.conn.SetInterrupt(old)
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
err := c.conn.Exec(query)
err := c.Conn.Exec(query)
if err != nil {
return nil, err
}
return newResult(c.conn), nil
}
func (c conn) Savepoint() sqlite3.Savepoint {
return c.conn.Savepoint()
}
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
return c.conn.OpenBlob(db, table, column, row, write)
}
func (c conn) SetInterrupt(ctx context.Context) (old context.Context) {
return c.conn.SetInterrupt(ctx)
return newResult(c.Conn), nil
}
type stmt struct {
stmt *sqlite3.Stmt
conn *sqlite3.Conn
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
}
var (
// Ensure these interfaces are implemented:
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
_ driver.NamedValueChecker = stmt{}
_ driver.StmtExecContext = &stmt{}
_ driver.StmtQueryContext = &stmt{}
_ driver.NamedValueChecker = &stmt{}
)
func (s stmt) Close() error {
return s.stmt.Close()
func (s *stmt) Close() error {
return s.Stmt.Close()
}
func (s stmt) NumInput() int {
n := s.stmt.BindCount()
func (s *stmt) NumInput() int {
n := s.Stmt.BindCount()
for i := 1; i <= n; i++ {
if s.stmt.BindName(i) != "" {
if s.Stmt.BindName(i) != "" {
return -1
}
}
@@ -255,16 +249,16 @@ func (s stmt) NumInput() int {
}
// Deprecated: use ExecContext instead.
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.ExecContext(context.Background(), namedValues(args))
}
// Deprecated: use QueryContext instead.
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), namedValues(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
// Use QueryContext to setup bindings.
// No need to close rows: that simply resets the statement, exec does the same.
_, err := s.QueryContext(ctx, args)
@@ -272,16 +266,16 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
return nil, err
}
err = s.stmt.Exec()
err = s.Stmt.Exec()
if err != nil {
return nil, err
}
return newResult(s.conn), nil
return newResult(s.Conn), nil
}
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.stmt.ClearBindings()
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.Stmt.ClearBindings()
if err != nil {
return nil, err
}
@@ -293,7 +287,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range []string{":", "@", "$"} {
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
@@ -302,23 +296,23 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
err = s.Stmt.BindBool(id, a)
case int:
err = s.stmt.BindInt(id, a)
err = s.Stmt.BindInt(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
err = s.Stmt.BindInt64(id, a)
case float64:
err = s.stmt.BindFloat(id, a)
err = s.Stmt.BindFloat(id, a)
case string:
err = s.stmt.BindText(id, a)
err = s.Stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
err = s.Stmt.BindBlob(id, a)
case sqlite3.ZeroBlob:
err = s.stmt.BindZeroBlob(id, int64(a))
err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case nil:
err = s.stmt.BindNull(id)
err = s.Stmt.BindNull(id)
default:
panic(util.AssertErr())
}
@@ -328,10 +322,10 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
}
}
return rows{ctx, s.stmt, s.conn}, nil
return &rows{ctx, s.Stmt, s.Conn}, nil
}
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil:
@@ -374,44 +368,44 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
type rows struct {
ctx context.Context
stmt *sqlite3.Stmt
conn *sqlite3.Conn
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
}
func (r rows) Close() error {
return r.stmt.Reset()
func (r *rows) Close() error {
return r.Stmt.Reset()
}
func (r rows) Columns() []string {
count := r.stmt.ColumnCount()
func (r *rows) Columns() []string {
count := r.Stmt.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.stmt.ColumnName(i)
columns[i] = r.Stmt.ColumnName(i)
}
return columns
}
func (r rows) Next(dest []driver.Value) error {
old := r.conn.SetInterrupt(r.ctx)
defer r.conn.SetInterrupt(old)
func (r *rows) Next(dest []driver.Value) error {
old := r.Conn.SetInterrupt(r.ctx)
defer r.Conn.SetInterrupt(old)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {
return err
}
return io.EOF
}
for i := range dest {
switch r.stmt.ColumnType(i) {
switch r.Stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.stmt.ColumnInt64(i)
dest[i] = r.Stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.stmt.ColumnFloat(i)
dest[i] = r.Stmt.ColumnFloat(i)
case sqlite3.BLOB:
dest[i] = r.stmt.ColumnRawBlob(i)
dest[i] = r.Stmt.ColumnRawBlob(i)
case sqlite3.TEXT:
dest[i] = stringOrTime(r.stmt.ColumnRawText(i))
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
@@ -423,5 +417,5 @@ func (r rows) Next(dest []driver.Value) error {
}
}
return r.stmt.Err()
return r.Stmt.Err()
}

View File

@@ -17,7 +17,8 @@ The following optional features are compiled in:
- [uuid](https://github.com/sqlite/sqlite/blob/master/ext/misc/uuid.c)
- [time](../sqlite3/time.c)
See the [configuration options](../sqlite3/sqlite_cfg.h).
See the [configuration options](../sqlite3/sqlite_cfg.h),
and [patches](../sqlite3) applied.
Built using [`wasi-sdk`](https://github.com/WebAssembly/wasi-sdk),
and [`binaryen`](https://github.com/WebAssembly/binaryen).

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

View File

@@ -37,11 +37,13 @@ sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_backup_pagecount
sqlite3_uri_parameter
sqlite3_uri_key
sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit

Binary file not shown.

View File

@@ -122,7 +122,7 @@ func Test_ErrorCode_Error(t *testing.T) {
for i := 0; i == int(ErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
got := ErrorCode(i).Error()
if got != want {
@@ -144,7 +144,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
want := "sqlite3: "
r := db.call(db.api.errstr, uint64(i))
want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING)
want += util.ReadString(db.mod, uint32(r), _MAX_STRING)
got := ExtendedErrorCode(i).Error()
if got != want {

7
go.mod
View File

@@ -4,9 +4,10 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.5
github.com/tetratelabs/wazero v1.1.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.0
github.com/psanford/httpreadat v0.1.0
github.com/tetratelabs/wazero v1.2.1
golang.org/x/sync v0.3.0
golang.org/x/sys v0.9.0
)
retract v0.4.0 // tagged from the wrong branch

14
go.sum
View File

@@ -1,8 +1,10 @@
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ=
github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

6
go.work Normal file
View File

@@ -0,0 +1,6 @@
go 1.19
use (
.
./gormlite
)

22
gormlite/LICENSE Normal file
View File

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2023 Nuno Cruces
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

26
gormlite/README.md Normal file
View File

@@ -0,0 +1,26 @@
# GORM SQLite Driver
[![Go Reference](https://pkg.go.dev/badge/image)](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
## Usage
```go
import (
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/gormlite"
"gorm.io/gorm"
)
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
```
Checkout [https://gorm.io](https://gorm.io) for details.
### Foreign-key constraint activation
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
```go
db, err := gorm.Open(gormlite.Open(
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
&gorm.Config{})
```

231
gormlite/ddlmod.go Normal file
View File

@@ -0,0 +1,231 @@
package gormlite
import (
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"gorm.io/gorm/migrator"
)
var (
sqliteSeparator = "`|\"|'|\t"
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)
func getAllColumns(s string) []string {
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
columns := make([]string, 0, len(allMatches))
for _, matches := range allMatches {
if len(matches) > 1 {
columns = append(columns, matches[1])
}
}
return columns
}
type ddl struct {
head string
fields []string
columns []migrator.ColumnType
}
func parseDDL(strs ...string) (*ddl, error) {
var result ddl
for _, str := range strs {
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
var (
ddlBody = sections[2]
ddlBodyRunes = []rune(ddlBody)
bracketLevel int
quote rune
buf string
)
ddlBodyRunesLen := len(ddlBodyRunes)
result.head = sections[1]
for idx := 0; idx < ddlBodyRunesLen; idx++ {
var (
next rune = 0
c = ddlBodyRunes[idx]
)
if idx+1 < ddlBodyRunesLen {
next = ddlBodyRunes[idx+1]
}
if sc := string(c); separatorRegexp.MatchString(sc) {
if c == next {
buf += sc // Skip escaped quote
idx++
} else if quote > 0 {
quote = 0
} else {
quote = c
}
} else if quote == 0 {
if c == '(' {
bracketLevel++
} else if c == ')' {
bracketLevel--
} else if bracketLevel == 0 {
if c == ',' {
result.fields = append(result.fields, strings.TrimSpace(buf))
buf = ""
continue
}
}
}
if bracketLevel < 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
buf += string(c)
}
if bracketLevel != 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
if buf != "" {
result.fields = append(result.fields, strings.TrimSpace(buf))
}
for _, f := range result.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") {
continue
}
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
for _, name := range getAllColumns(f) {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
columnType := migrator.ColumnType{
NameValue: sql.NullString{String: matches[1], Valid: true},
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
matchUpper := strings.ToUpper(matches[3])
if strings.Contains(matchUpper, " NOT NULL") {
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
} else if strings.Contains(matchUpper, " NULL") {
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " UNIQUE") {
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " PRIMARY") {
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
}
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
if strings.ToLower(defaultMatches[1]) != "null" {
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
}
}
// data type length
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
if len(matches) == 1 && len(matches[0]) == 2 {
size, _ := strconv.Atoi(matches[0][1])
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
}
result.columns = append(result.columns, columnType)
}
}
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
for _, column := range getAllColumns(matches[1]) {
for idx, c := range result.columns {
if c.NameValue.String == column {
c.UniqueValue = sql.NullBool{Bool: strings.ToUpper(strings.Fields(str)[1]) == "UNIQUE", Valid: true}
result.columns[idx] = c
}
}
}
} else {
return nil, errors.New("invalid DDL")
}
}
return &result, nil
}
func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
}
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}
func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return
}
}
d.fields = append(d.fields, sql)
}
func (d *ddl) removeConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}
return false
}
func (d *ddl) getColumns() []string {
res := []string{}
for _, f := range d.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
continue
}
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
match := reg.FindStringSubmatch(f)
if match != nil {
res = append(res, "`"+match[1]+"`")
}
}
return res
}

352
gormlite/ddlmod_test.go Normal file
View File

@@ -0,0 +1,352 @@
package gormlite
import (
"database/sql"
"testing"
"gorm.io/gorm/migrator"
"gorm.io/gorm/utils/tests"
)
func TestParseDDL(t *testing.T) {
params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{"with_fk", []string{
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"no brackets", []string{"create table test"}, 0, nil},
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
"table_name_with_dash",
[]string{
"CREATE TABLE `test-a` (`id` int NOT NULL)",
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "int", Valid: true},
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
},
},
},
{
"unique index",
[]string{
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
{
"non-unique index",
[]string{
"CREATE TABLE `test-c` (`field` integer NOT NULL)",
"CREATE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)
if err != nil {
panic(err.Error())
}
tests.AssertEqual(t, p.sql[0], ddl.compile())
if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}
func TestParseDDL_Whitespaces(t *testing.T) {
testColumns := []migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
},
{
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
},
}
params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{
"with_newline",
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_newline_2",
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_missing_space",
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_many_spaces",
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)
if err != nil {
panic(err.Error())
}
if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}
func TestParseDDL_error(t *testing.T) {
params := []struct {
name string
sql string
}{
{"invalid_cmd", "CREATE TABLE"},
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
_, err := parseDDL(p.sql)
if err == nil {
t.Fail()
}
})
}
}
func TestAddConstraint(t *testing.T) {
params := []struct {
name string
fields []string
cName string
sql string
expect []string
}{
{
name: "add_new",
fields: []string{"`id` integer NOT NULL"},
cName: "fk_users_notes",
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
},
{
name: "update",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
cName: "fk_users_notes",
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
},
{
name: "add_check",
fields: []string{"`id` integer NOT NULL"},
cName: "name_checker",
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
},
{
name: "update_check",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
cName: "name_checker",
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL := ddl{fields: p.fields}
testDDL.addConstraint(p.cName, p.sql)
tests.AssertEqual(t, p.expect, testDDL.fields)
})
}
}
func TestRemoveConstraint(t *testing.T) {
params := []struct {
name string
fields []string
cName string
success bool
expect []string
}{
{
name: "fk",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
cName: "fk_users_notes",
success: true,
expect: []string{"`id` integer NOT NULL"},
},
{
name: "check",
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
cName: "name_checker",
success: true,
expect: []string{"`id` integer NOT NULL"},
},
{
name: "none",
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
cName: "nothing",
success: false,
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL := ddl{fields: p.fields}
success := testDDL.removeConstraint(p.cName)
tests.AssertEqual(t, p.success, success)
tests.AssertEqual(t, p.expect, testDDL.fields)
})
}
}
func TestGetColumns(t *testing.T) {
params := []struct {
name string
ddl string
columns []string
}{
{
name: "with_fk",
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
columns: []string{"`id`", "`text`", "`user_id`"},
},
{
name: "with_check",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
},
{
name: "with_escaped_quote",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
},
{
name: "with_generated_column",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
},
{
name: "with_new_line",
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
"id" integer PRIMARY KEY AUTOINCREMENT,
"created_at" datetime NOT NULL,
"updated_at" datetime NOT NULL,
"created_by" integer NOT NULL DEFAULT 0,
"updated_by" integer NOT NULL DEFAULT 0,
"role_id" integer NOT NULL,
"menu_id" bigint NOT NULL
)`,
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL, err := parseDDL(p.ddl)
if err != nil {
panic(err.Error())
}
cols := testDDL.getColumns()
tests.AssertEqual(t, p.columns, cols)
})
}
}

11
gormlite/download.sh Executable file
View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/ddlmod_test.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/error_translator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/migrator.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite.go"
curl -#OL "https://github.com/go-gorm/sqlite/raw/master/sqlite_test.go"

View File

@@ -0,0 +1,21 @@
package gormlite
import (
"errors"
"github.com/ncruces/go-sqlite3"
"gorm.io/gorm"
)
func (dialector Dialector) Translate(err error) error {
switch {
case
errors.Is(err, sqlite3.CONSTRAINT_UNIQUE),
errors.Is(err, sqlite3.CONSTRAINT_PRIMARYKEY):
return gorm.ErrDuplicatedKey
case
errors.Is(err, sqlite3.CONSTRAINT_FOREIGNKEY):
return err // gorm.ErrForeignKeyViolated (gorm v1.25.2)
}
return err
}

16
gormlite/go.mod Normal file
View File

@@ -0,0 +1,16 @@
module github.com/ncruces/go-sqlite3/gormlite
go 1.19
require (
github.com/ncruces/go-sqlite3 v0.7.2
gorm.io/gorm v1.25.1
)
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/ncruces/julianday v0.1.5 // indirect
github.com/tetratelabs/wazero v1.2.1 // indirect
golang.org/x/sys v0.9.0 // indirect
)

14
gormlite/go.sum Normal file
View File

@@ -0,0 +1,14 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/go-sqlite3 v0.7.2 h1:K7jU4rnUxFdUsbEL+B0Xc+VexLTEwGSO6Qh91Qh4hYc=
github.com/ncruces/go-sqlite3 v0.7.2/go.mod h1:t3dP4AP9rJddU+ffFv0h6fWyeOCEhjxrYc1nsYG7aQI=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/tetratelabs/wazero v1.2.1 h1:J4X2hrGzJvt+wqltuvcSjHQ7ujQxA9gb6PeMs4qlUWs=
github.com/tetratelabs/wazero v1.2.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

431
gormlite/migrator.go Normal file
View File

@@ -0,0 +1,431 @@
package gormlite
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Migrator struct {
migrator.Migrator
}
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
var enabled int
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
if enabled == 1 {
m.DB.Exec("PRAGMA foreign_keys = OFF")
defer m.DB.Exec("PRAGMA foreign_keys = ON")
}
return fc()
}
func (m Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
return m.RunWithoutForeignKey(func() error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
return nil
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
}
if name != "" {
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
).Row().Scan(&count)
}
return nil
})
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
})
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
sqls []string
sqlDDL *ddl
)
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
return err
}
if sqlDDL, err = parseDDL(sqls...); err != nil {
return err
}
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil {
return err
}
defer func() {
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnType := migrator.ColumnType{SQLColumnType: c}
for _, column := range sqlDDL.columns {
if column.NameValue.String == c.Name() {
column.SQLColumnType = c
columnType = column
break
}
}
columnTypes = append(columnTypes, columnType)
}
return err
})
return columnTypes, execErr
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, "")
return createSQL, nil, nil
})
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
var (
constraintName string
constraintSql string
constraintValues []interface{}
)
if constraint != nil {
constraintName = constraint.Name
constraintSql, constraintValues = buildConstraint(constraint)
} else if chk != nil {
constraintName = chk.Name
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
}
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()
return createSQL, constraintValues, nil
})
})
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()
return createSQL, nil, nil
})
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
).Row().Scan(&count)
return nil
})
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
if name != "" {
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
).Row().Scan(&count)
}
return nil
})
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
if err := m.DropIndex(value, oldName); err != nil {
return err
}
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
})
}
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (m Migrator) getRawDDL(table string) (string, error) {
var createSQL string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
if m.DB.Error != nil {
return "", m.DB.Error
}
return createSQL, nil
}
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
table = *tablePtr
}
rawDDL, err := m.getRawDDL(table)
if err != nil {
return err
}
newTableName := table + "__temp"
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
if err != nil {
return err
}
if createSQL == "" {
return nil
}
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createDDL, err := parseDDL(createSQL)
if err != nil {
return err
}
columns := createDDL.getColumns()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
return err
}
queries := []string{
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
fmt.Sprintf("DROP TABLE `%v`", table),
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
}
for _, query := range queries {
if err := tx.Exec(query).Error; err != nil {
return err
}
}
return nil
})
})
}

219
gormlite/sqlite.go Normal file
View File

@@ -0,0 +1,219 @@
// Package gormlite provides a GORM driver for SQLite.
package gormlite
import (
"context"
"database/sql"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
DSN string
Conn gorm.ConnPool
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
if err != nil {
return err
}
db.ConnPool = conn
}
var version string
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
return err
}
// https://www.sqlite.org/releaselog/3_35_0.html
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
} else {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
LastInsertIDReversed: true,
})
}
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
stmt.WriteString("INSERT ")
if insert.Modifier != "" {
stmt.WriteString(insert.Modifier)
stmt.WriteByte(' ')
}
stmt.WriteString("INTO ")
if insert.Table.Name == "" {
stmt.WriteQuoted(stmt.Table)
} else {
stmt.WriteQuoted(insert.Table)
}
return
}
}
c.Build(builder)
},
"LIMIT": func(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
var lmt = -1
if limit.Limit != nil && *limit.Limit >= 0 {
lmt = *limit.Limit
}
if lmt >= 0 || limit.Offset > 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(lmt))
}
if limit.Offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
},
"FOR": func(c clause.Clause, builder clause.Builder) {
if _, ok := c.Expression.(clause.Locking); ok {
// SQLite3 does not support row-level locking.
return
}
c.Build(builder)
},
}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
if field.AutoIncrement {
return clause.Expr{SQL: "NULL"}
}
// doesn't work, will raise error
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
if strings.Contains(str, ".") {
for idx, str := range strings.Split(str, ".") {
if idx > 0 {
writer.WriteString(".`")
}
writer.WriteString(str)
writer.WriteByte('`')
}
} else {
writer.WriteString(str)
writer.WriteByte('`')
}
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
return "integer"
}
case schema.Float:
return "real"
case schema.String:
return "text"
case schema.Time:
// Distinguish between schema.Time and tag time
if val, ok := field.TagSettings["TYPE"]; ok {
return val
} else {
return "datetime"
}
case schema.Bytes:
return "blob"
}
return string(field.DataType)
}
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}
func compareVersion(version1, version2 string) int {
n, m := len(version1), len(version2)
i, j := 0, 0
for i < n || j < m {
x := 0
for ; i < n && version1[i] != '.'; i++ {
x = x*10 + int(version1[i]-'0')
}
i++
y := 0
for ; j < m && version2[j] != '.'; j++ {
y = y*10 + int(version2[j]-'0')
}
j++
if x > y {
return 1
}
if x < y {
return -1
}
}
return 0
}

64
gormlite/sqlite_test.go Normal file
View File

@@ -0,0 +1,64 @@
package gormlite
import (
"fmt"
"testing"
"gorm.io/gorm"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
rows := []struct {
description string
dialector *Dialector
openSuccess bool
query string
querySuccess bool
}{
{
description: "Default driver",
dialector: &Dialector{
DSN: InMemoryDSN,
},
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
db, err := gorm.Open(row.dialector, &gorm.Config{})
if !row.openSuccess {
if err == nil {
t.Errorf("Expected Open to fail.")
}
return
}
if err != nil {
t.Errorf("Expected Open to succeed; got error: %v", err)
}
if db == nil {
t.Errorf("Expected db to be non-nil.")
}
if row.query != "" {
err = db.Exec(row.query).Error
if !row.querySuccess {
if err == nil {
t.Errorf("Expected query to fail.")
}
return
}
if err != nil {
t.Errorf("Expected query to succeed; got error: %v", err)
}
}
})
}
}

18
gormlite/test.sh Executable file
View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
set -euo pipefail
cd -P -- "$(dirname -- "$0")"
rm -rf gorm/ tests/
git clone --filter=blob:none --branch=v1.25.1 https://github.com/go-gorm/gorm.git
mv gorm/tests tests
rm -rf gorm/
patch -p1 -N < tests.patch
cd tests
go mod tidy && go work use . && go test
cd ..
rm -rf tests/
go work use -r .

63
gormlite/tests.patch Normal file
View File

@@ -0,0 +1,63 @@
diff --git a/tests/.gitignore b/tests/.gitignore
index 08cb523..72e8ffc 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -1 +1 @@
-go.sum
+*
diff --git a/tests/go.mod b/tests/go.mod
index f47d175..84b80c2 100644
--- a/tests/go.mod
+++ b/tests/go.mod
@@ -7,13 +7,13 @@ require (
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.8
- github.com/mattn/go-sqlite3 v1.14.16 // indirect
+ github.com/ncruces/go-sqlite3 v0.7.2
+ github.com/ncruces/go-sqlite3/gormlite v0.0.0
golang.org/x/crypto v0.8.0 // indirect
gorm.io/driver/mysql v1.5.0
gorm.io/driver/postgres v1.5.0
- gorm.io/driver/sqlite v1.5.0
gorm.io/driver/sqlserver v1.4.3
- gorm.io/gorm v1.25.0
+ gorm.io/gorm v1.25.1
)
-replace gorm.io/gorm => ../
+replace github.com/ncruces/go-sqlite3/gormlite => ../
diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go
index 1412169..472434b 100644
--- a/tests/scanner_valuer_test.go
+++ b/tests/scanner_valuer_test.go
@@ -170,10 +170,10 @@ func (data *EncryptedData) Scan(value interface{}) error {
return errors.New("Too short")
}
- *data = b[3:]
+ *data = append((*data)[0:], b[3:]...)
return nil
} else if s, ok := value.(string); ok {
- *data = []byte(s)[3:]
+ *data = []byte(s[3:])
return nil
}
diff --git a/tests/tests_test.go b/tests/tests_test.go
index 90eb847..cd9af43 100644
--- a/tests/tests_test.go
+++ b/tests/tests_test.go
@@ -7,9 +7,11 @@ import (
"path/filepath"
"time"
+ _ "github.com/ncruces/go-sqlite3/embed"
+ sqlite "github.com/ncruces/go-sqlite3/gormlite"
+
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
- "gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"

22
internal/util/bool.go Normal file
View File

@@ -0,0 +1,22 @@
package util
import "strings"
func ParseBool(s string) (b, ok bool) {
if len(s) == 0 {
return false, false
}
if s[0] == '0' {
return false, true
}
if '1' <= s[0] && s[0] <= '9' {
return true, true
}
switch strings.ToLower(s) {
case "true", "yes", "on":
return true, true
case "false", "no", "off":
return false, true
}
return false, false
}

View File

@@ -0,0 +1,28 @@
package util
import "testing"
func TestParseBool(t *testing.T) {
tests := []struct {
str string
val bool
ok bool
}{
{"", false, false},
{"0", false, true},
{"1", true, true},
{"9", true, true},
{"T", false, false},
{"true", true, true},
{"FALSE", false, true},
{"false?", false, false},
}
for _, tt := range tests {
t.Run(tt.str, func(t *testing.T) {
gotVal, gotOK := ParseBool(tt.str)
if gotVal != tt.val || gotOK != tt.ok {
t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
}
})
}
}

View File

@@ -3,88 +3,90 @@ package util
import (
"math"
"testing"
"github.com/tetratelabs/wazero/experimental/wazerotest"
)
func TestView_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
View(mock, 0, 8)
t.Error("want panic")
}
func TestView_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
View(mock, 126, 8)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
View(mock, wazerotest.PageSize-2, 8)
t.Error("want panic")
}
func TestView_overflow(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
View(mock, 1, math.MaxInt64)
t.Error("want panic")
}
func TestReadUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint32(mock, 0)
t.Error("want panic")
}
func TestReadUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint32(mock, 126)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint32(mock, wazerotest.PageSize-2)
t.Error("want panic")
}
func TestReadUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint64(mock, 0)
t.Error("want panic")
}
func TestReadUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadUint64(mock, 126)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadUint64(mock, wazerotest.PageSize-2)
t.Error("want panic")
}
func TestWriteUint32_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint32(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint32_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint32(mock, 126, 1)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint32(mock, wazerotest.PageSize-2, 1)
t.Error("want panic")
}
func TestWriteUint64_nil(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint64(mock, 0, 1)
t.Error("want panic")
}
func TestWriteUint64_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
WriteUint64(mock, 126, 1)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
WriteUint64(mock, wazerotest.PageSize-2, 1)
t.Error("want panic")
}
func TestReadString_range(t *testing.T) {
defer func() { _ = recover() }()
mock := NewMockModule(128)
ReadString(mock, 130, math.MaxUint32)
mock := wazerotest.NewModule(wazerotest.NewFixedMemory(wazerotest.PageSize))
ReadString(mock, wazerotest.PageSize+2, math.MaxUint32)
t.Error("want panic")
}

View File

@@ -1,148 +0,0 @@
package util
import (
"encoding/binary"
"math"
"github.com/tetratelabs/wazero/api"
)
func NewMockModule(size uint32) api.Module {
mem := mockMemory{buf: make([]byte, size)}
return mockModule{&mem, nil}
}
type mockModule struct {
memory api.Memory
api.Module
}
func (m mockModule) Memory() api.Memory { return m.memory }
func (m mockModule) String() string { return "mockModule" }
func (m mockModule) Name() string { return "mockModule" }
type mockMemory struct {
buf []byte
api.Memory
}
func (m mockMemory) Definition() api.MemoryDefinition { return nil }
func (m mockMemory) Size() uint32 { return uint32(len(m.buf)) }
func (m mockMemory) ReadByte(offset uint32) (byte, bool) {
if offset >= m.Size() {
return 0, false
}
return m.buf[offset], true
}
func (m mockMemory) ReadUint16Le(offset uint32) (uint16, bool) {
if !m.hasSize(offset, 2) {
return 0, false
}
return binary.LittleEndian.Uint16(m.buf[offset : offset+2]), true
}
func (m mockMemory) ReadUint32Le(offset uint32) (uint32, bool) {
if !m.hasSize(offset, 4) {
return 0, false
}
return binary.LittleEndian.Uint32(m.buf[offset : offset+4]), true
}
func (m mockMemory) ReadFloat32Le(offset uint32) (float32, bool) {
v, ok := m.ReadUint32Le(offset)
if !ok {
return 0, false
}
return math.Float32frombits(v), true
}
func (m mockMemory) ReadUint64Le(offset uint32) (uint64, bool) {
if !m.hasSize(offset, 8) {
return 0, false
}
return binary.LittleEndian.Uint64(m.buf[offset : offset+8]), true
}
func (m mockMemory) ReadFloat64Le(offset uint32) (float64, bool) {
v, ok := m.ReadUint64Le(offset)
if !ok {
return 0, false
}
return math.Float64frombits(v), true
}
func (m mockMemory) Read(offset, byteCount uint32) ([]byte, bool) {
if !m.hasSize(offset, byteCount) {
return nil, false
}
return m.buf[offset : offset+byteCount : offset+byteCount], true
}
func (m mockMemory) WriteByte(offset uint32, v byte) bool {
if offset >= m.Size() {
return false
}
m.buf[offset] = v
return true
}
func (m mockMemory) WriteUint16Le(offset uint32, v uint16) bool {
if !m.hasSize(offset, 2) {
return false
}
binary.LittleEndian.PutUint16(m.buf[offset:], v)
return true
}
func (m mockMemory) WriteUint32Le(offset, v uint32) bool {
if !m.hasSize(offset, 4) {
return false
}
binary.LittleEndian.PutUint32(m.buf[offset:], v)
return true
}
func (m mockMemory) WriteFloat32Le(offset uint32, v float32) bool {
return m.WriteUint32Le(offset, math.Float32bits(v))
}
func (m mockMemory) WriteUint64Le(offset uint32, v uint64) bool {
if !m.hasSize(offset, 8) {
return false
}
binary.LittleEndian.PutUint64(m.buf[offset:], v)
return true
}
func (m mockMemory) WriteFloat64Le(offset uint32, v float64) bool {
return m.WriteUint64Le(offset, math.Float64bits(v))
}
func (m mockMemory) Write(offset uint32, val []byte) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m.buf[offset:], val)
return true
}
func (m mockMemory) WriteString(offset uint32, val string) bool {
if !m.hasSize(offset, uint32(len(val))) {
return false
}
copy(m.buf[offset:], val)
return true
}
func (m *mockMemory) Grow(delta uint32) (result uint32, ok bool) {
prev := (len(m.buf) + 65535) / 65536
m.buf = append(m.buf, make([]byte, 65536*delta)...)
return uint32(prev), true
}
func (m mockMemory) hasSize(offset uint32, byteCount uint32) bool {
return uint64(offset)+uint64(byteCount) <= uint64(len(m.buf))
}

View File

@@ -1,242 +0,0 @@
package util
import (
"math"
"testing"
)
func Test_mockMemory_byte(t *testing.T) {
const want byte = 98
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteByte(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadByte(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint16(t *testing.T) {
const want uint16 = 9876
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint16Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint16Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint16Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint32(t *testing.T) {
const want uint32 = 987654321
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_uint64(t *testing.T) {
const want uint64 = 9876543210
mock := NewMockModule(128)
_, ok := mock.Memory().ReadUint64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteUint64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadUint64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %d, want %d", got, want)
}
}
func Test_mockMemory_float32(t *testing.T) {
const want float32 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat32Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat32Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat32Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_float64(t *testing.T) {
const want float64 = math.Pi
mock := NewMockModule(128)
_, ok := mock.Memory().ReadFloat64Le(128)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(128, 0)
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteFloat64Le(0, want)
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().ReadFloat64Le(0)
if !ok {
t.Error("want ok")
}
if got != want {
t.Errorf("got %f, want %f", got, want)
}
}
func Test_mockMemory_bytes(t *testing.T) {
const want string = "\xca\xfe\xba\xbe"
mock := NewMockModule(128)
_, ok := mock.Memory().Read(128, uint32(len(want)))
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(128, []byte(want))
if ok {
t.Error("want error")
}
ok = mock.Memory().WriteString(128, want)
if ok {
t.Error("want error")
}
ok = mock.Memory().Write(0, []byte(want))
if !ok {
t.Error("want ok")
}
got, ok := mock.Memory().Read(0, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
ok = mock.Memory().WriteString(64, want)
if !ok {
t.Error("want ok")
}
got, ok = mock.Memory().Read(64, uint32(len(want)))
if !ok {
t.Error("want ok")
}
if string(got) != want {
t.Errorf("got %q, want %q", got, want)
}
}
func Test_mockMemory_grow(t *testing.T) {
mock := NewMockModule(128)
_, ok := mock.Memory().ReadByte(65536)
if ok {
t.Error("want error")
}
got, ok := mock.Memory().Grow(1)
if !ok {
t.Error("want ok")
}
if got != 1 {
t.Errorf("got %d, want 1", got)
}
_, ok = mock.Memory().ReadByte(65536)
if !ok {
t.Error("want ok")
}
}

View File

@@ -9,7 +9,7 @@ import (
"sync"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
"github.com/ncruces/go-sqlite3/vfs"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
)
@@ -23,6 +23,7 @@ import (
var (
Binary []byte // WASM binary to load.
Path string // Path to load the binary from.
Debug bool // Whether to enable SQLite debug stack traces.
)
var sqlite3 struct {
@@ -51,9 +52,9 @@ func instantiateModule() (*module, error) {
func compileModule() {
ctx := context.Background()
sqlite3.runtime = wazero.NewRuntime(ctx)
sqlite3.runtime = wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfig().WithDebugInfoEnabled(Debug))
env := sqlite3vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
env := vfs.ExportHostFunctions(sqlite3.runtime.NewHostModuleBuilder("env"))
_, sqlite3.err = env.Instantiate(ctx)
if sqlite3.err != nil {
return
@@ -83,9 +84,9 @@ type module struct {
}
func newModule(mod api.Module) (m *module, err error) {
m = &module{}
m = new(module)
m.mod = mod
m.ctx, m.vfs = sqlite3vfs.NewContext(context.Background())
m.ctx, m.vfs = vfs.NewContext(context.Background())
getFun := func(name string) api.Function {
f := mod.ExportedFunction(name)
@@ -139,9 +140,6 @@ func newModule(mod api.Module) (m *module, err error) {
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
@@ -153,6 +151,9 @@ func newModule(mod api.Module) (m *module, err error) {
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
}
if err != nil {
return nil, err
@@ -177,22 +178,17 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
panic(util.OOMErr)
}
var r []uint64
r = m.call(m.api.errstr, rc)
if r != nil {
err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
if r := m.call(m.api.errstr, rc); r != 0 {
err.str = util.ReadString(m.mod, uint32(r), _MAX_STRING)
}
r = m.call(m.api.errmsg, uint64(handle))
if r != nil {
err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING)
if r := m.call(m.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(m.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
r = m.call(m.api.erroff, uint64(handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
if r := m.call(m.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
@@ -203,7 +199,7 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error {
return &err
}
func (m *module) call(fn api.Function, params ...uint64) []uint64 {
func (m *module) call(fn api.Function, params ...uint64) uint64 {
copy(m.arg[:], params)
err := fn.CallWithStack(m.ctx, m.arg[:])
if err != nil {
@@ -211,7 +207,7 @@ func (m *module) call(fn api.Function, params ...uint64) []uint64 {
m.vfs.Close()
panic(err)
}
return m.arg[:]
return m.arg[0]
}
func (m *module) free(ptr uint32) {
@@ -225,8 +221,7 @@ func (m *module) new(size uint64) uint32 {
if size > _MAX_ALLOCATION_SIZE {
panic(util.OOMErr)
}
r := m.call(m.api.malloc, size)
ptr := uint32(r[0])
ptr := uint32(m.call(m.api.malloc, size))
if ptr == 0 && size != 0 {
panic(util.OOMErr)
}
@@ -340,9 +335,6 @@ type sqliteAPI struct {
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
@@ -354,5 +346,8 @@ type sqliteAPI struct {
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
changes api.Function
lastRowid api.Function
autocommit api.Function
destructor uint32
}

View File

@@ -76,7 +76,6 @@ func TestConn_newArena(t *testing.T) {
defer arena.free()
const title = "Lorem ipsum"
ptr := arena.string(title)
if ptr == 0 {
t.Fatalf("got nullptr")
@@ -93,6 +92,19 @@ func TestConn_newArena(t *testing.T) {
if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body {
t.Errorf("got %q, want %q", got, body)
}
ptr = arena.bytes(nil)
if ptr != 0 {
t.Errorf("want nullptr")
}
ptr = arena.bytes([]byte(title))
if ptr == 0 {
t.Fatalf("got nullptr")
}
if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title {
t.Errorf("got %q, want %q", got, title)
}
arena.free()
}

20
sqlite3/deserialize.patch Normal file
View File

@@ -0,0 +1,20 @@
--- sqlite3.c.orig
+++ sqlite3.c
@@ -60425,7 +60425,7 @@
int rc = SQLITE_OK; /* Return code */
int tempFile = 0; /* True for temp files (incl. in-memory files) */
int memDb = 0; /* True if this is an in-memory file */
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
int memJM = 0; /* Memory journal mode */
#else
# define memJM 0
@@ -60628,7 +60628,7 @@
int fout = 0; /* VFS flags returned by xOpen() */
rc = sqlite3OsOpen(pVfs, pPager->zFilename, pPager->fd, vfsFlags, &fout);
assert( !memDb );
-#ifndef SQLITE_OMIT_DESERIALIZE
+#if 1
pPager->memVfs = memJM = (fout&SQLITE_OPEN_MEMORY)!=0;
#endif
readOnly = (fout&SQLITE_OPEN_READONLY)!=0;

View File

@@ -8,6 +8,9 @@ unzip -d . sqlite-amalgamation-*.zip
mv sqlite-amalgamation-*/sqlite3* .
rm -rf sqlite-amalgamation-*
patch < vfs_find.patch
patch < deserialize.patch
cd ext/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/decimal.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/uint.c"
@@ -17,7 +20,7 @@ curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/regexp.c
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/ext/misc/series.c"
cd ~-
cd ../sqlite3vfs/tests/mptest/testdata/
cd ../vfs/tests/mptest/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/mptest.c"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config01.test"
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/config02.test"
@@ -26,6 +29,6 @@ curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/crash02.su
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/mptest/multiwrite01.test"
cd ~-
cd ../sqlite3vfs/tests/speedtest1/testdata/
cd ../vfs/tests/speedtest1/testdata/
curl -#OL "https://github.com/sqlite/sqlite/raw/version-3.42.0/test/speedtest1.c"
cd ~-

View File

@@ -30,6 +30,9 @@
// Other Options
#define SQLITE_ALLOW_URI_AUTHORITY
#define SQLITE_ENABLE_BATCH_ATOMIC_WRITE
#define SQLITE_ENABLE_ATOMIC_WRITE
#define SQLITE_OMIT_DESERIALIZE
// Because WASM does not support shared memory,
// SQLite disables WAL for WASM builds.
@@ -55,16 +58,5 @@
// #define SQLITE_ENABLE_SESSION
// #define SQLITE_ENABLE_PREUPDATE_HOOK
// https://stackoverflow.com/a/50616684
#define SECOND(...) SECOND_I(__VA_ARGS__, , )
#define SECOND_I(A, B, ...) B
#define GLUE(A, B) GLUE_I(A, B)
#define GLUE_I(A, B) A##_##B
#define CREATE_REPLACER(A) SECOND(GLUE(A, __LINE__), A)
#define REPLACE_AT_LINE(A) , A
// Implemented in vfs.c.
int localtime_s(struct tm *const pTm, time_t const *const pTime);
#define sqlite3_vfs_find CREATE_REPLACER(sqlite3_vfs_find_wrapper)
#define sqlite3_vfs_find_wrapper_25397 REPLACE_AT_LINE(sqlite3_vfs_find)
int localtime_s(struct tm *const pTm, time_t const *const pTime);

View File

@@ -87,8 +87,7 @@ int localtime_s(struct tm *const pTm, time_t const *const pTime) {
return go_localtime(pTm, (sqlite3_int64)*pTime);
}
#undef sqlite3_vfs_find
sqlite3_vfs *sqlite3_vfs_find_wrapper(const char *zVfsName) {
sqlite3_vfs *sqlite3_vfs_find(const char *zVfsName) {
if (zVfsName) {
static sqlite3_vfs *go_vfs_list;
sqlite3_vfs *found = NULL;
@@ -130,7 +129,7 @@ sqlite3_vfs *sqlite3_vfs_find_wrapper(const char *zVfsName) {
return go_vfs_list;
}
}
return sqlite3_vfs_find(zVfsName);
return sqlite3_vfs_find_orig(zVfsName);
}
static_assert(offsetof(struct go_file, handle) == 4, "Unexpected offset");

11
sqlite3/vfs_find.patch Normal file
View File

@@ -0,0 +1,11 @@
--- sqlite3.c.orig
+++ sqlite3.c
@@ -25394,7 +25394,7 @@
** Locate a VFS by name. If no name is given, simply return the
** first VFS on the list.
*/
-SQLITE_API sqlite3_vfs *sqlite3_vfs_find(const char *zVfs){
+SQLITE_API sqlite3_vfs *sqlite3_vfs_find_orig(const char *zVfs){
sqlite3_vfs *pVfs = 0;
#if SQLITE_THREADSAFE
sqlite3_mutex *mutex;

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aafb873c26d32dbf7498e4c4dd243eb6411c056432818121463d8cb49bc15f70
size 1479293

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2638f5458f93b60f97006456f60e3e294cffb9ac22b20dba748fbd12e34f056d
size 1515841

54
stmt.go
View File

@@ -29,7 +29,7 @@ func (s *Stmt) Close() error {
r := s.c.call(s.c.api.finalize, uint64(s.handle))
s.handle = 0
return s.c.error(r[0])
return s.c.error(r)
}
// Reset resets the prepared statement object.
@@ -38,7 +38,7 @@ func (s *Stmt) Close() error {
func (s *Stmt) Reset() error {
r := s.c.call(s.c.api.reset, uint64(s.handle))
s.err = nil
return s.c.error(r[0])
return s.c.error(r)
}
// ClearBindings resets all bindings on the prepared statement.
@@ -46,7 +46,7 @@ func (s *Stmt) Reset() error {
// https://www.sqlite.org/c3ref/clear_bindings.html
func (s *Stmt) ClearBindings() error {
r := s.c.call(s.c.api.clearBindings, uint64(s.handle))
return s.c.error(r[0])
return s.c.error(r)
}
// Step evaluates the SQL statement.
@@ -61,13 +61,13 @@ func (s *Stmt) ClearBindings() error {
func (s *Stmt) Step() bool {
s.c.checkInterrupt()
r := s.c.call(s.c.api.step, uint64(s.handle))
if r[0] == _ROW {
if r == _ROW {
return true
}
if r[0] == _DONE {
if r == _DONE {
s.err = nil
} else {
s.err = s.c.error(r[0])
s.err = s.c.error(r)
}
return false
}
@@ -94,7 +94,7 @@ func (s *Stmt) Exec() error {
func (s *Stmt) BindCount() int {
r := s.c.call(s.c.api.bindCount,
uint64(s.handle))
return int(r[0])
return int(r)
}
// BindIndex returns the index of a parameter in the prepared statement
@@ -106,7 +106,7 @@ func (s *Stmt) BindIndex(name string) int {
namePtr := s.c.arena.string(name)
r := s.c.call(s.c.api.bindIndex,
uint64(s.handle), uint64(namePtr))
return int(r[0])
return int(r)
}
// BindName returns the name of a parameter in the prepared statement.
@@ -117,7 +117,7 @@ func (s *Stmt) BindName(param int) string {
r := s.c.call(s.c.api.bindName,
uint64(s.handle), uint64(param))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
return ""
}
@@ -152,7 +152,7 @@ func (s *Stmt) BindInt(param int, value int) error {
func (s *Stmt) BindInt64(param int, value int64) error {
r := s.c.call(s.c.api.bindInteger,
uint64(s.handle), uint64(param), uint64(value))
return s.c.error(r[0])
return s.c.error(r)
}
// BindFloat binds a float64 to the prepared statement.
@@ -162,7 +162,7 @@ func (s *Stmt) BindInt64(param int, value int64) error {
func (s *Stmt) BindFloat(param int, value float64) error {
r := s.c.call(s.c.api.bindFloat,
uint64(s.handle), uint64(param), math.Float64bits(value))
return s.c.error(r[0])
return s.c.error(r)
}
// BindText binds a string to the prepared statement.
@@ -175,7 +175,7 @@ func (s *Stmt) BindText(param int, value string) error {
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
return s.c.error(r)
}
// BindBlob binds a []byte to the prepared statement.
@@ -189,7 +189,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(value)),
uint64(s.c.api.destructor))
return s.c.error(r[0])
return s.c.error(r)
}
// BindZeroBlob binds a zero-filled, length n BLOB to the prepared statement.
@@ -199,7 +199,7 @@ func (s *Stmt) BindBlob(param int, value []byte) error {
func (s *Stmt) BindZeroBlob(param int, n int64) error {
r := s.c.call(s.c.api.bindZeroBlob,
uint64(s.handle), uint64(param), uint64(n))
return s.c.error(r[0])
return s.c.error(r)
}
// BindNull binds a NULL to the prepared statement.
@@ -209,7 +209,7 @@ func (s *Stmt) BindZeroBlob(param int, n int64) error {
func (s *Stmt) BindNull(param int) error {
r := s.c.call(s.c.api.bindNull,
uint64(s.handle), uint64(param))
return s.c.error(r[0])
return s.c.error(r)
}
// BindTime binds a [time.Time] to the prepared statement.
@@ -244,7 +244,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(buf)),
uint64(s.c.api.destructor), _UTF8)
return s.c.error(r[0])
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
@@ -253,7 +253,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
func (s *Stmt) ColumnCount() int {
r := s.c.call(s.c.api.columnCount,
uint64(s.handle))
return int(r[0])
return int(r)
}
// ColumnName returns the name of the result column.
@@ -264,7 +264,7 @@ func (s *Stmt) ColumnName(col int) string {
r := s.c.call(s.c.api.columnName,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
panic(util.OOMErr)
}
@@ -278,7 +278,7 @@ func (s *Stmt) ColumnName(col int) string {
func (s *Stmt) ColumnType(col int) Datatype {
r := s.c.call(s.c.api.columnType,
uint64(s.handle), uint64(col))
return Datatype(r[0])
return Datatype(r)
}
// ColumnBool returns the value of the result column as a bool.
@@ -310,7 +310,7 @@ func (s *Stmt) ColumnInt(col int) int {
func (s *Stmt) ColumnInt64(col int) int64 {
r := s.c.call(s.c.api.columnInteger,
uint64(s.handle), uint64(col))
return int64(r[0])
return int64(r)
}
// ColumnFloat returns the value of the result column as a float64.
@@ -320,7 +320,7 @@ func (s *Stmt) ColumnInt64(col int) int64 {
func (s *Stmt) ColumnFloat(col int) float64 {
r := s.c.call(s.c.api.columnFloat,
uint64(s.handle), uint64(col))
return math.Float64frombits(r[0])
return math.Float64frombits(r)
}
// ColumnTime returns the value of the result column as a [time.Time].
@@ -375,17 +375,17 @@ func (s *Stmt) ColumnRawText(col int) []byte {
r := s.c.call(s.c.api.columnText,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r[0])
return util.View(s.c.mod, ptr, r)
}
// ColumnRawBlob returns the value of the result column as a []byte.
@@ -398,17 +398,17 @@ func (s *Stmt) ColumnRawBlob(col int) []byte {
r := s.c.call(s.c.api.columnBlob,
uint64(s.handle), uint64(col))
ptr := uint32(r[0])
ptr := uint32(r)
if ptr == 0 {
r = s.c.call(s.c.api.errcode, uint64(s.c.handle))
s.err = s.c.error(r[0])
s.err = s.c.error(r)
return nil
}
r = s.c.call(s.c.api.columnBytes,
uint64(s.handle), uint64(col))
return util.View(s.c.mod, ptr, r[0])
return util.View(s.c.mod, ptr, r)
}
// Return true if stmt is an empty SQL statement.

View File

@@ -124,4 +124,46 @@ func TestBackup(t *testing.T) {
t.Fatal(err)
}
}()
func() { // Incremental.
db, err := sqlite3.Open(backupName)
if err != nil {
t.Fatal(err)
}
defer db.Close()
b, err := db.BackupInit("main", ":memory:")
if err != nil {
t.Fatal(err)
}
defer b.Close()
done, err := b.Step(1)
if done {
t.Error("want false")
}
if err != nil {
t.Error(err)
}
n := b.Remaining()
if n != 1 {
t.Errorf("got %d", n)
}
n = b.PageCount()
if n != 2 {
t.Errorf("got %d", n)
}
err = b.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}()
}

View File

@@ -3,6 +3,8 @@ package tests
import (
"context"
"errors"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
@@ -23,6 +25,55 @@ func TestConn_Open_dir(t *testing.T) {
}
}
func TestConn_Open_notfound(t *testing.T) {
t.Parallel()
_, err := sqlite3.OpenFlags("test.db", sqlite3.OPEN_READONLY)
if err == nil {
t.Fatal("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func TestConn_Open_modeof(t *testing.T) {
t.Parallel()
dir := t.TempDir()
file := filepath.Join(dir, "test.db")
mode := filepath.Join(dir, "modeof.txt")
fd, err := os.OpenFile(mode, os.O_CREATE, 0624)
if err != nil {
t.Fatal(err)
}
fi, err := fd.Stat()
if err != nil {
t.Fatal(err)
}
fd.Close()
db, err := sqlite3.Open("file:" + file + "?modeof=" + mode)
if err != nil {
t.Fatal(err)
}
di, err := os.Stat(file)
if err != nil {
t.Fatal(err)
}
db.Close()
if di.Mode() != fi.Mode() {
t.Errorf("got %v, want %v", di.Mode(), fi.Mode())
}
_, err = sqlite3.Open("file:" + file + "?modeof=" + mode + "2")
if err == nil {
t.Fatal("want error")
}
}
func TestConn_Close(t *testing.T) {
var conn *sqlite3.Conn
conn.Close()

View File

@@ -6,19 +6,24 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestDB_memory(t *testing.T) {
t.Parallel()
testDB(t, ":memory:")
}
func TestDB_file(t *testing.T) {
t.Parallel()
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func testDB(t *testing.T, name string) {
t.Parallel()
func TestDB_vfs(t *testing.T) {
testDB(t, "file:test.db?vfs=memdb")
}
func testDB(t *testing.T, name string) {
db, err := sqlite3.Open(name)
if err != nil {
t.Fatal(err)

View File

@@ -27,18 +27,32 @@ func TestDriver(t *testing.T) {
}
defer conn.Close()
_, err = conn.ExecContext(ctx,
res, err := conn.ExecContext(ctx,
`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
changes, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if changes != 0 {
t.Errorf("got %d want 0", changes)
}
id, err := res.LastInsertId()
if err != nil {
t.Fatal(err)
}
if id != 0 {
t.Errorf("got %d want 0", changes)
}
res, err := conn.ExecContext(ctx,
res, err = conn.ExecContext(ctx,
`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
changes, err := res.RowsAffected()
changes, err = res.RowsAffected()
if err != nil {
t.Fatal(err)
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
func TestParallel(t *testing.T) {
@@ -21,7 +22,29 @@ func TestParallel(t *testing.T) {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
name := "file:" +
filepath.Join(t.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
testIntegrity(t, name)
}
func TestMemory(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
name := "file:/test.db?vfs=memdb" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -31,8 +54,14 @@ func TestMultiProcess(t *testing.T) {
t.Skip("skipping in short mode")
}
name := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbname", name)
file := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbfile", file)
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
out, err := cmd.StdoutPipe()
@@ -57,11 +86,17 @@ func TestMultiProcess(t *testing.T) {
}
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
file := os.Getenv("TestMultiProcess_dbfile")
if file == "" || testing.Short() {
t.SkipNow()
}
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, 1000)
}
@@ -73,16 +108,6 @@ func testParallel(t *testing.T, name string, n int) {
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA synchronous=off;
PRAGMA locking_mode=normal;
PRAGMA journal_mode=truncate;
`)
if err != nil {
return err
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err

36
tests/vfs_test.go Normal file
View File

@@ -0,0 +1,36 @@
package tests
import (
"errors"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/ncruces/go-sqlite3/vfs/readervfs"
)
func TestMemoryVFS_Open_notfound(t *testing.T) {
memdb.Delete("demo.db")
_, err := sqlite3.Open("file:/demo.db?vfs=memdb&mode=ro")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}
func TestReaderVFS_Open_notfound(t *testing.T) {
readervfs.Delete("demo.db")
_, err := sqlite3.Open("file:demo.db?vfs=reader")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
}

9
vfs/README.md Normal file
View File

@@ -0,0 +1,9 @@
# Go SQLite VFS API
This package implements the SQLite [OS Interface](https://www.sqlite.org/vfs.html) (aka VFS).
It replaces the default VFS with a pure Go implementation,
that is tested on Linux, macOS and Windows,
but which should also work on illumos and the various BSDs.
It also exposes interfaces that should allow you to implement your own custom VFSes.

View File

@@ -1,7 +1,7 @@
// Package sqlite3vfs wraps the C SQLite VFS API.
package sqlite3vfs
// Package vfs wraps the C SQLite VFS API.
package vfs
import "sync"
import "net/url"
// A VFS defines the interface between the SQLite core and the underlying operating system.
//
@@ -15,6 +15,15 @@ type VFS interface {
FullPathname(name string) (string, error)
}
// VFSParams extends VFS to with the ability to handle URI parameters
// through the OpenParams method.
//
// https://www.sqlite.org/c3ref/uri_boolean.html
type VFSParams interface {
VFS
OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error)
}
// A File represents an open file in the OS interface layer.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.
@@ -27,7 +36,7 @@ type File interface {
WriteAt(p []byte, off int64) (n int, err error)
Truncate(size int64) error
Sync(flags SyncFlag) error
FileSize() (int64, error)
Size() (int64, error)
Lock(lock LockLevel) error
Unlock(lock LockLevel) error
CheckReservedLock() (bool, error)
@@ -35,7 +44,7 @@ type File interface {
DeviceCharacteristics() DeviceCharacteristic
}
// FileLockState extends [File] to implement the
// FileLockState extends File to implement the
// SQLITE_FCNTL_LOCKSTATE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
@@ -44,7 +53,7 @@ type FileLockState interface {
LockState() LockLevel
}
// FileSizeHint extends [File] to implement the
// FileSizeHint extends File to implement the
// SQLITE_FCNTL_SIZE_HINT file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
@@ -53,7 +62,7 @@ type FileSizeHint interface {
SizeHint(size int64) error
}
// FileHasMoved extends [File] to implement the
// FileHasMoved extends File to implement the
// SQLITE_FCNTL_HAS_MOVED file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
@@ -62,7 +71,7 @@ type FileHasMoved interface {
HasMoved() (bool, error)
}
// FilePowersafeOverwrite extends [File] to implement the
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_POWERSAFE_OVERWRITE file control opcode.
//
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
@@ -72,38 +81,23 @@ type FilePowersafeOverwrite interface {
SetPowersafeOverwrite(bool)
}
var (
vfsRegistry map[string]VFS
vfsRegistryMtx sync.Mutex
)
// Find returns a VFS given its name.
// If there is no match, nil is returned.
// FilePowersafeOverwrite extends File to implement the
// SQLITE_FCNTL_COMMIT_PHASETWO file control opcode.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Find(name string) VFS {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
return vfsRegistry[name]
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileCommitPhaseTwo interface {
File
CommitPhaseTwo() error
}
// Register registers a VFS.
// FileBatchAtomicWrite extends File to implement the
// SQLITE_FCNTL_BEGIN_ATOMIC_WRITE, SQLITE_FCNTL_COMMIT_ATOMIC_WRITE
// and SQLITE_FCNTL_ROLLBACK_ATOMIC_WRITE file control opcodes.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Register(name string, vfs VFS) {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
if vfsRegistry == nil {
vfsRegistry = map[string]VFS{}
}
vfsRegistry[name] = vfs
}
// Unregister unregisters a VFS.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Unregister(name string) {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
delete(vfsRegistry, name)
// https://www.sqlite.org/c3ref/c_fcntl_begin_atomic_write.html
type FileBatchAtomicWrite interface {
File
BeginAtomicWrite() error
CommitAtomicWrite() error
RollbackAtomicWrite() error
}

View File

@@ -1,4 +1,4 @@
package sqlite3vfs
package vfs
import "github.com/ncruces/go-sqlite3/internal/util"
@@ -19,6 +19,7 @@ const (
_OK _ErrorCode = util.OK
_PERM _ErrorCode = util.PERM
_BUSY _ErrorCode = util.BUSY
_READONLY _ErrorCode = util.READONLY
_IOERR _ErrorCode = util.IOERR
_NOTFOUND _ErrorCode = util.NOTFOUND
_CANTOPEN _ErrorCode = util.CANTOPEN
@@ -38,7 +39,11 @@ const (
_IOERR_CLOSE _ErrorCode = util.IOERR_CLOSE
_IOERR_SEEK _ErrorCode = util.IOERR_SEEK
_IOERR_DELETE_NOENT _ErrorCode = util.IOERR_DELETE_NOENT
_IOERR_BEGIN_ATOMIC _ErrorCode = util.IOERR_BEGIN_ATOMIC
_IOERR_COMMIT_ATOMIC _ErrorCode = util.IOERR_COMMIT_ATOMIC
_IOERR_ROLLBACK_ATOMIC _ErrorCode = util.IOERR_ROLLBACK_ATOMIC
_CANTOPEN_FULLPATH _ErrorCode = util.CANTOPEN_FULLPATH
_CANTOPEN_ISDIR _ErrorCode = util.CANTOPEN_ISDIR
_OK_SYMLINK _ErrorCode = util.OK_SYMLINK
)

View File

@@ -1,17 +1,15 @@
package sqlite3vfs
package vfs
import (
"context"
"errors"
"io"
"io/fs"
"net/url"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
type vfsOS struct{}
@@ -36,10 +34,10 @@ func (vfsOS) FullPathname(path string) (string, error) {
func (vfsOS) Delete(path string, syncDir bool) error {
err := os.Remove(path)
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return _IOERR_DELETE_NOENT
}
return err
}
if runtime.GOOS != "windows" && syncDir {
@@ -71,6 +69,10 @@ func (vfsOS) Access(name string, flags AccessFlag) (bool, error) {
}
func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
return vfsOS{}.OpenParams(name, flags, nil)
}
func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error) {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
@@ -93,9 +95,18 @@ func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
f, err = osOpenFile(name, oflags, 0666)
}
if err != nil {
if errors.Is(err, syscall.EISDIR) {
return nil, flags, _CANTOPEN_ISDIR
}
return nil, flags, err
}
if modeof := params.Get("modeof"); modeof != "" {
if err = osSetMode(f, modeof); err != nil {
f.Close()
return nil, flags, _IOERR_FSTAT
}
}
if flags&OPEN_DELETEONCLOSE != 0 {
os.Remove(f.Name())
}
@@ -128,42 +139,6 @@ var (
_ FilePowersafeOverwrite = &vfsFile{}
)
func vfsFileNew(vfs *vfsState, file File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) {
const fileHandleOffset = 4
id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file)
util.WriteUint32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
return vfs.files[id]
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
}
func (f *vfsFile) Sync(flags SyncFlag) error {
dataonly := (flags & SYNC_DATAONLY) != 0
fullsync := (flags & 0x0f) == SYNC_FULL
@@ -187,7 +162,7 @@ func (f *vfsFile) Sync(flags SyncFlag) error {
return nil
}
func (f *vfsFile) FileSize() (int64, error) {
func (f *vfsFile) Size() (int64, error) {
return f.Seek(0, io.SeekEnd)
}
@@ -212,7 +187,10 @@ func (f *vfsFile) HasMoved() (bool, error) {
return false, err
}
pi, err := os.Stat(f.Name())
if err != nil && !errors.Is(err, fs.ErrNotExist) {
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return true, nil
}
return false, err
}
return !os.SameFile(fi, pi), nil

View File

@@ -1,4 +1,4 @@
package sqlite3vfs
package vfs
import (
"os"
@@ -14,73 +14,73 @@ const (
_SHARED_SIZE = 510
)
func (file *vfsFile) Lock(eLock LockLevel) error {
func (f *vfsFile) Lock(lock LockLevel) error {
// Argument check. SQLite never explicitly requests a pending lock.
if eLock != LOCK_SHARED && eLock != LOCK_RESERVED && eLock != LOCK_EXCLUSIVE {
if lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
switch {
case file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE:
case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE:
// Connection state check.
panic(util.AssertErr())
case file.lock == LOCK_NONE && eLock > LOCK_SHARED:
case f.lock == LOCK_NONE && lock > LOCK_SHARED:
// We never move from unlocked to anything higher than a shared lock.
panic(util.AssertErr())
case file.lock != LOCK_SHARED && eLock == LOCK_RESERVED:
case f.lock != LOCK_SHARED && lock == LOCK_RESERVED:
// A shared lock is always held when a reserved lock is requested.
panic(util.AssertErr())
}
// If we already have an equal or more restrictive lock, do nothing.
if file.lock >= eLock {
if f.lock >= lock {
return nil
}
// Do not allow any kind of write-lock on a read-only database.
if file.readOnly && eLock >= LOCK_RESERVED {
if f.readOnly && lock >= LOCK_RESERVED {
return _IOERR_LOCK
}
switch eLock {
switch lock {
case LOCK_SHARED:
// Must be unlocked to get SHARED.
if file.lock != LOCK_NONE {
if f.lock != LOCK_NONE {
panic(util.AssertErr())
}
if rc := osGetSharedLock(file.File, file.lockTimeout); rc != _OK {
if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK {
return rc
}
file.lock = LOCK_SHARED
f.lock = LOCK_SHARED
return nil
case LOCK_RESERVED:
// Must be SHARED to get RESERVED.
if file.lock != LOCK_SHARED {
if f.lock != LOCK_SHARED {
panic(util.AssertErr())
}
if rc := osGetReservedLock(file.File, file.lockTimeout); rc != _OK {
if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK {
return rc
}
file.lock = LOCK_RESERVED
f.lock = LOCK_RESERVED
return nil
case LOCK_EXCLUSIVE:
// Must be SHARED, RESERVED or PENDING to get EXCLUSIVE.
if file.lock <= LOCK_NONE || file.lock >= LOCK_EXCLUSIVE {
if f.lock <= LOCK_NONE || f.lock >= LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// A PENDING lock is needed before acquiring an EXCLUSIVE lock.
if file.lock < LOCK_PENDING {
if rc := osGetPendingLock(file.File); rc != _OK {
if f.lock < LOCK_PENDING {
if rc := osGetPendingLock(f.File); rc != _OK {
return rc
}
file.lock = LOCK_PENDING
f.lock = LOCK_PENDING
}
if rc := osGetExclusiveLock(file.File, file.lockTimeout); rc != _OK {
if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK {
return rc
}
file.lock = LOCK_EXCLUSIVE
f.lock = LOCK_EXCLUSIVE
return nil
default:
@@ -88,33 +88,33 @@ func (file *vfsFile) Lock(eLock LockLevel) error {
}
}
func (file *vfsFile) Unlock(eLock LockLevel) error {
func (f *vfsFile) Unlock(lock LockLevel) error {
// Argument check.
if eLock != LOCK_NONE && eLock != LOCK_SHARED {
if lock != LOCK_NONE && lock != LOCK_SHARED {
panic(util.AssertErr())
}
// Connection state check.
if file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE {
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
// If we don't have a more restrictive lock, do nothing.
if file.lock <= eLock {
if f.lock <= lock {
return nil
}
switch eLock {
switch lock {
case LOCK_SHARED:
if rc := osDowngradeLock(file.File, file.lock); rc != _OK {
if rc := osDowngradeLock(f.File, f.lock); rc != _OK {
return rc
}
file.lock = LOCK_SHARED
f.lock = LOCK_SHARED
return nil
case LOCK_NONE:
rc := osReleaseLock(file.File, file.lock)
file.lock = LOCK_NONE
rc := osReleaseLock(f.File, f.lock)
f.lock = LOCK_NONE
return rc
default:
@@ -122,16 +122,16 @@ func (file *vfsFile) Unlock(eLock LockLevel) error {
}
}
func (file *vfsFile) CheckReservedLock() (bool, error) {
func (f *vfsFile) CheckReservedLock() (bool, error) {
// Connection state check.
if file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE {
if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE {
panic(util.AssertErr())
}
if file.lock >= LOCK_RESERVED {
if f.lock >= LOCK_RESERVED {
return true, nil
}
return osCheckReservedLock(file.File)
return osCheckReservedLock(f.File)
}
func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode {

View File

@@ -1,4 +1,4 @@
package sqlite3vfs
package vfs
import (
"context"
@@ -8,6 +8,7 @@ import (
"testing"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/experimental/wazerotest"
)
func Test_vfsLock(t *testing.T) {
@@ -39,7 +40,7 @@ func Test_vfsLock(t *testing.T) {
pFile2 = 16
pOutput = 32
)
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()

9
vfs/memdb/README.md Normal file
View File

@@ -0,0 +1,9 @@
# Go `"memdb"` SQLite VFS
This package implements the [`"memdb"`](https://www.sqlite.org/src/file/src/memdb.c)
SQLite VFS in pure Go.
It has some benefits over the C version:
- the memory backing the database needs not be contiguous,
- the database can grow/shrink incrementally without copying,
- reader-writer concurrency is slightly improved.

59
vfs/memdb/api.go Normal file
View File

@@ -0,0 +1,59 @@
// Package memdb implements the "memdb" SQLite VFS.
//
// The "memdb" [vfs.VFS] allows the same in-memory database to be shared
// among multiple database connections in the same process,
// as long as the database name begins with "/".
//
// Importing package memdb registers the VFS.
//
// import _ "github.com/ncruces/go-sqlite3/vfs/memdb"
package memdb
import (
"sync"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("memdb", memVFS{})
}
var (
memoryMtx sync.Mutex
// +checklocks:memoryMtx
memoryDBs = map[string]*memDB{}
)
// Create creates a shared memory database,
// using data as its initial contents.
// The new database takes ownership of data,
// and the caller should not use data after this call.
func Create(name string, data []byte) {
memoryMtx.Lock()
defer memoryMtx.Unlock()
db := new(memDB)
db.size = int64(len(data))
sectors := divRoundUp(db.size, sectorSize)
db.data = make([]*[sectorSize]byte, sectors)
for i := range db.data {
sector := data[i*sectorSize:]
if len(sector) >= sectorSize {
db.data[i] = (*[sectorSize]byte)(sector)
} else {
db.data[i] = new([sectorSize]byte)
copy((*db.data[i])[:], sector)
}
}
memoryDBs[name] = db
}
// Delete deletes a shared memory database.
func Delete(name string) {
memoryMtx.Lock()
defer memoryMtx.Unlock()
delete(memoryDBs, name)
}

51
vfs/memdb/example_test.go Normal file
View File

@@ -0,0 +1,51 @@
package memdb_test
import (
"database/sql"
"fmt"
"log"
_ "embed"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/test.db
var testDB []byte
func Example() {
memdb.Create("test.db", testDB)
db, err := sql.Open("sqlite3", "file:/test.db?vfs=memdb")
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`)
if err != nil {
log.Fatal(err)
}
rows, err := db.Query(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id, name string
err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s %s\n", id, name)
}
// Output:
// 0 go
// 1 zig
// 2 whatever
// 3 rust
}

294
vfs/memdb/memdb.go Normal file
View File

@@ -0,0 +1,294 @@
package memdb
import (
"io"
"runtime"
"strings"
"sync"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
type memVFS struct{}
func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
// Allowed file types:
// - databases, which only do page aligned reads/writes;
// - temp journals, used by the sorter, which does the same.
const types = vfs.OPEN_MAIN_DB |
vfs.OPEN_TRANSIENT_DB |
vfs.OPEN_TEMP_DB |
vfs.OPEN_TEMP_JOURNAL
if flags&types == 0 {
return nil, flags, sqlite3.CANTOPEN
}
var db *memDB
shared := strings.HasPrefix(name, "/")
if shared {
memoryMtx.Lock()
defer memoryMtx.Unlock()
db = memoryDBs[name[1:]]
}
if db == nil {
if flags&vfs.OPEN_CREATE == 0 {
return nil, flags, sqlite3.CANTOPEN
}
db = new(memDB)
}
if shared {
memoryDBs[name[1:]] = db // +checklocksignore: lock is held
}
return &memFile{
memDB: db,
readOnly: flags&vfs.OPEN_READONLY != 0,
}, flags | vfs.OPEN_MEMORY, nil
}
func (memVFS) Delete(name string, dirSync bool) error {
return sqlite3.IOERR_DELETE
}
func (memVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return false, nil
}
func (memVFS) FullPathname(name string) (string, error) {
return name, nil
}
// Must be a multiple of 64K (the largest page size).
const sectorSize = 65536
type memDB struct {
// +checklocks:lockMtx
pending *memFile
// +checklocks:lockMtx
reserved *memFile
// +checklocks:dataMtx
data []*[sectorSize]byte
// +checklocks:dataMtx
size int64
// +checklocks:lockMtx
shared int
lockMtx sync.Mutex
dataMtx sync.RWMutex
}
type memFile struct {
*memDB
lock vfs.LockLevel
readOnly bool
}
var (
// Ensure these interfaces are implemented:
_ vfs.FileLockState = &memFile{}
_ vfs.FileSizeHint = &memFile{}
)
func (m *memFile) Close() error {
return m.Unlock(vfs.LOCK_NONE)
}
func (m *memFile) ReadAt(b []byte, off int64) (n int, err error) {
m.dataMtx.RLock()
defer m.dataMtx.RUnlock()
if off >= m.size {
return 0, io.EOF
}
base := off / sectorSize
rest := off % sectorSize
have := int64(sectorSize)
if base == int64(len(m.data))-1 {
have = modRoundUp(m.size, sectorSize)
}
n = copy(b, (*m.data[base])[rest:have])
if n < len(b) {
// Assume reads are page aligned.
return 0, io.ErrNoProgress
}
return n, nil
}
func (m *memFile) WriteAt(b []byte, off int64) (n int, err error) {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
base := off / sectorSize
rest := off % sectorSize
for base >= int64(len(m.data)) {
m.data = append(m.data, new([sectorSize]byte))
}
n = copy((*m.data[base])[rest:], b)
if n < len(b) {
// Assume writes are page aligned.
return 0, io.ErrShortWrite
}
if size := off + int64(len(b)); size > m.size {
m.size = size
}
return n, nil
}
func (m *memFile) Truncate(size int64) error {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
return m.truncate(size)
}
// +checklocks:m.dataMtx
func (m *memFile) truncate(size int64) error {
if size < m.size {
base := size / sectorSize
rest := size % sectorSize
if rest != 0 {
clear((*m.data[base])[rest:])
}
}
sectors := divRoundUp(size, sectorSize)
for sectors > int64(len(m.data)) {
m.data = append(m.data, new([sectorSize]byte))
}
clear(m.data[sectors:])
m.data = m.data[:sectors]
m.size = size
return nil
}
func (*memFile) Sync(flag vfs.SyncFlag) error {
return nil
}
func (m *memFile) Size() (int64, error) {
m.dataMtx.RLock()
defer m.dataMtx.RUnlock()
return m.size, nil
}
func (m *memFile) Lock(lock vfs.LockLevel) error {
if m.lock >= lock {
return nil
}
if m.readOnly && lock >= vfs.LOCK_RESERVED {
return sqlite3.IOERR_LOCK
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
switch lock {
case vfs.LOCK_SHARED:
if m.pending != nil {
return sqlite3.BUSY
}
m.shared++
case vfs.LOCK_RESERVED:
if m.reserved != nil {
return sqlite3.BUSY
}
m.reserved = m
case vfs.LOCK_EXCLUSIVE:
if m.lock < vfs.LOCK_PENDING {
if m.pending != nil {
return sqlite3.BUSY
}
m.lock = vfs.LOCK_PENDING
m.pending = m
}
for start := time.Now(); m.shared > 1; {
if time.Since(start) > time.Millisecond {
return sqlite3.BUSY
}
m.lockMtx.Unlock()
runtime.Gosched()
m.lockMtx.Lock()
}
}
m.lock = lock
return nil
}
func (m *memFile) Unlock(lock vfs.LockLevel) error {
if m.lock <= lock {
return nil
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
if m.pending == m {
m.pending = nil
}
if m.reserved == m {
m.reserved = nil
}
if lock < vfs.LOCK_SHARED {
m.shared--
}
m.lock = lock
return nil
}
func (m *memFile) CheckReservedLock() (bool, error) {
if m.lock >= vfs.LOCK_RESERVED {
return true, nil
}
m.lockMtx.Lock()
defer m.lockMtx.Unlock()
return m.reserved != nil, nil
}
func (*memFile) SectorSize() int {
return sectorSize
}
func (*memFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_ATOMIC |
vfs.IOCAP_SEQUENTIAL |
vfs.IOCAP_SAFE_APPEND |
vfs.IOCAP_POWERSAFE_OVERWRITE
}
func (m *memFile) SizeHint(size int64) error {
m.dataMtx.Lock()
defer m.dataMtx.Unlock()
if size > m.size {
return m.truncate(size)
}
return nil
}
func (m *memFile) LockState() vfs.LockLevel {
return m.lock
}
func divRoundUp(a, b int64) int64 {
return (a + b - 1) / b
}
func modRoundUp(a, b int64) int64 {
return b - (b-a%b)%b
}
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = zero
}
}

BIN
vfs/memdb/testdata/test.db vendored Normal file

Binary file not shown.

View File

@@ -1,6 +1,6 @@
//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd)
package sqlite3vfs
package vfs
import (
"os"

View File

@@ -1,6 +1,6 @@
//go:build !sqlite3_bsd
package sqlite3vfs
package vfs
import (
"io"

View File

@@ -1,4 +1,4 @@
package sqlite3vfs
package vfs
import (
"os"

View File

@@ -1,6 +1,6 @@
//go:build linux || illumos
package sqlite3vfs
package vfs
import (
"os"

View File

@@ -1,6 +1,6 @@
//go:build !linux && (!darwin || sqlite3_bsd)
package sqlite3vfs
package vfs
import (
"io"

View File

@@ -1,10 +1,11 @@
//go:build unix
package sqlite3vfs
package vfs
import (
"io/fs"
"os"
"syscall"
"time"
"golang.org/x/sys/unix"
@@ -25,6 +26,18 @@ func osAccess(path string, flags AccessFlag) error {
return unix.Access(path, access)
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
if sys, ok := fi.Sys().(*syscall.Stat_t); ok {
file.Chown(int(sys.Uid), int(sys.Gid))
}
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Test the PENDING lock before acquiring a new SHARED lock.
if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending {

View File

@@ -1,4 +1,4 @@
package sqlite3vfs
package vfs
import (
"io/fs"
@@ -47,6 +47,15 @@ func osAccess(path string, flags AccessFlag) error {
return nil
}
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}
func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode {
// Acquire the PENDING lock temporarily before acquiring a new SHARED lock.
rc := osReadLock(file, _PENDING_BYTE, 1, timeout)

5
vfs/readervfs/README.md Normal file
View File

@@ -0,0 +1,5 @@
# Go `"reader"` SQLite VFS
This package implements a `"reader"` SQLite VFS
that allows accessing any [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt)
as an immutable SQLite database.

84
vfs/readervfs/api.go Normal file
View File

@@ -0,0 +1,84 @@
// Package readervfs implements an SQLite VFS for immutable databases.
//
// The "reader" [vfs.VFS] permits accessing any [io.ReaderAt]
// as an immutable SQLite database.
//
// Importing package readervfs registers the VFS.
//
// import _ "github.com/ncruces/go-sqlite3/vfs/readervfs"
package readervfs
import (
"io"
"io/fs"
"sync"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("reader", readerVFS{})
}
var (
readerMtx sync.RWMutex
// +checklocks:readerMtx
readerDBs = map[string]SizeReaderAt{}
)
// Create creates an immutable database from reader.
// The caller should ensure that data from reader does not mutate,
// otherwise SQLite might return incorrect query results and/or [sqlite3.CORRUPT] errors.
func Create(name string, reader SizeReaderAt) {
readerMtx.Lock()
defer readerMtx.Unlock()
readerDBs[name] = reader
}
// Delete deletes a shared memory database.
func Delete(name string) {
readerMtx.Lock()
defer readerMtx.Unlock()
delete(readerDBs, name)
}
// A SizeReaderAt is a ReaderAt with a Size method.
// Use [NewSizeReaderAt] to adapt different Size interfaces.
type SizeReaderAt interface {
Size() (int64, error)
io.ReaderAt
}
// NewSizeReaderAt returns a SizeReaderAt given an io.ReaderAt
// that implements one of:
// - Size() (int64, error)
// - Size() int64
// - Len() int
// - Stat() (fs.FileInfo, error)
// - Seek(offset int64, whence int) (int64, error)
func NewSizeReaderAt(r io.ReaderAt) SizeReaderAt {
return sizer{r}
}
type sizer struct{ io.ReaderAt }
func (s sizer) Size() (int64, error) {
switch s := s.ReaderAt.(type) {
case interface{ Size() (int64, error) }:
return s.Size()
case interface{ Size() int64 }:
return s.Size(), nil
case interface{ Len() int }:
return int64(s.Len()), nil
case interface{ Stat() (fs.FileInfo, error) }:
fi, err := s.Stat()
if err != nil {
return 0, err
}
return fi.Size(), nil
case io.Seeker:
return s.Seek(0, io.SeekEnd)
}
return 0, sqlite3.IOERR_SEEK
}

View File

@@ -0,0 +1,95 @@
package readervfs_test
import (
"bytes"
"database/sql"
"fmt"
"log"
_ "embed"
_ "github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/vfs/readervfs"
"github.com/psanford/httpreadat"
)
//go:embed testdata/test.db
var testDB []byte
func Example_http() {
readervfs.Create("demo.db", httpreadat.New("https://www.sanford.io/demo.db"))
defer readervfs.Delete("demo.db")
db, err := sql.Open("sqlite3", "file:demo.db?vfs=reader")
if err != nil {
log.Fatal(err)
}
defer db.Close()
magname := map[int]string{
3: "thousand",
6: "million",
9: "billion",
}
rows, err := db.Query(`
SELECT period, data_value, magntude, units FROM csv
WHERE period > '2010'
LIMIT 10`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var period, units string
var value int64
var mag int
err = rows.Scan(&period, &value, &mag, &units)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s: %d %s %s\n", period, value, magname[mag], units)
}
// Output:
// 2010.03: 17463 million Dollars
// 2010.06: 17260 million Dollars
// 2010.09: 15419 million Dollars
// 2010.12: 17088 million Dollars
// 2011.03: 18516 million Dollars
// 2011.06: 18835 million Dollars
// 2011.09: 16390 million Dollars
// 2011.12: 18748 million Dollars
// 2012.03: 18477 million Dollars
// 2012.06: 18270 million Dollars
}
func Example_embed() {
readervfs.Create("test.db", readervfs.NewSizeReaderAt(bytes.NewReader(testDB)))
defer readervfs.Delete("test.db")
db, err := sql.Open("sqlite3", "file:test.db?vfs=reader")
if err != nil {
log.Fatal(err)
}
defer db.Close()
rows, err := db.Query(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id, name string
err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%s %s\n", id, name)
}
// Output:
// 0 go
// 1 zig
// 2 whatever
}

74
vfs/readervfs/reader.go Normal file
View File

@@ -0,0 +1,74 @@
package readervfs
import (
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
type readerVFS struct{}
// Open implements the [vfs.VFS] interface.
func (readerVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 {
return nil, flags, sqlite3.CANTOPEN
}
readerMtx.RLock()
defer readerMtx.RUnlock()
if ra, ok := readerDBs[name]; ok {
return readerFile{ra}, flags | vfs.OPEN_READONLY, nil
}
return nil, flags, sqlite3.CANTOPEN
}
// Delete implements the [vfs.VFS] interface.
func (readerVFS) Delete(name string, dirSync bool) error {
return sqlite3.IOERR_DELETE
}
// Access implements the [vfs.VFS] interface.
func (readerVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return false, nil
}
// FullPathname implements the [vfs.VFS] interface.
func (readerVFS) FullPathname(name string) (string, error) {
return name, nil
}
type readerFile struct{ SizeReaderAt }
func (readerFile) Close() error {
return nil
}
func (readerFile) WriteAt(b []byte, off int64) (n int, err error) {
return 0, sqlite3.READONLY
}
func (readerFile) Truncate(size int64) error {
return sqlite3.READONLY
}
func (readerFile) Sync(flag vfs.SyncFlag) error {
return nil
}
func (readerFile) Lock(lock vfs.LockLevel) error {
return nil
}
func (readerFile) Unlock(lock vfs.LockLevel) error {
return nil
}
func (readerFile) CheckReservedLock() (bool, error) {
return false, nil
}
func (readerFile) SectorSize() int {
return 0
}
func (readerFile) DeviceCharacteristics() vfs.DeviceCharacteristic {
return vfs.IOCAP_IMMUTABLE
}

View File

@@ -0,0 +1,87 @@
package readervfs
import (
"io"
"os"
"path/filepath"
"strings"
"testing"
)
func TestNewSizeReaderAt(t *testing.T) {
f, err := os.Create(filepath.Join(t.TempDir(), "abc.txt"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
n, err := NewSizeReaderAt(f).Size()
if err != nil {
t.Fatal(err)
}
if n != 0 {
t.Errorf("got %d", n)
}
reader := strings.NewReader("abc")
n, err = NewSizeReaderAt(reader).Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
n, err = NewSizeReaderAt(readlener{reader, reader.Len()}).Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
n, err = NewSizeReaderAt(readsizer{reader, reader.Size()}).Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
n, err = NewSizeReaderAt(readseeker{reader, reader}).Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
_, err = NewSizeReaderAt(readerat{reader}).Size()
if err == nil {
t.Error("want error")
}
}
type readlener struct {
io.ReaderAt
len int
}
func (l readlener) Len() int { return l.len }
type readsizer struct {
io.ReaderAt
size int64
}
func (l readsizer) Size() (int64, error) { return l.size, nil }
type readseeker struct {
io.ReaderAt
io.Seeker
}
type readerat struct {
io.ReaderAt
}

BIN
vfs/readervfs/testdata/test.db vendored Normal file

Binary file not shown.

47
vfs/registry.go Normal file
View File

@@ -0,0 +1,47 @@
package vfs
import "sync"
var (
// +checklocks:vfsRegistryMtx
vfsRegistry map[string]VFS
vfsRegistryMtx sync.RWMutex
)
// Find returns a VFS given its name.
// If there is no match, nil is returned.
// If name is empty, the default VFS is returned.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Find(name string) VFS {
if name == "" || name == "os" {
return vfsOS{}
}
vfsRegistryMtx.RLock()
defer vfsRegistryMtx.RUnlock()
return vfsRegistry[name]
}
// Register registers a VFS.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Register(name string, vfs VFS) {
if name == "" || name == "os" {
return
}
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
if vfsRegistry == nil {
vfsRegistry = map[string]VFS{}
}
vfsRegistry[name] = vfs
}
// Unregister unregisters a VFS.
//
// https://www.sqlite.org/c3ref/vfs_find.html
func Unregister(name string) {
vfsRegistryMtx.Lock()
defer vfsRegistryMtx.Unlock()
delete(vfsRegistry, name)
}

View File

@@ -1,18 +1,18 @@
package sqlite3vfs_test
package vfs_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
"github.com/ncruces/go-sqlite3/vfs"
)
type testVFS struct {
*testing.T
}
func (t testVFS) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, sqlite3vfs.OpenFlag, error) {
func (t testVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
t.Log("Open", name, flags)
t.SkipNow()
return nil, flags, nil
@@ -23,7 +23,7 @@ func (t testVFS) Delete(name string, syncDir bool) error {
return nil
}
func (t testVFS) Access(name string, flags sqlite3vfs.AccessFlag) (bool, error) {
func (t testVFS) Access(name string, flags vfs.AccessFlag) (bool, error) {
t.Log("Access", name, flags)
return true, nil
}
@@ -34,9 +34,8 @@ func (t testVFS) FullPathname(name string) (string, error) {
}
func TestRegister(t *testing.T) {
vfs := testVFS{t}
sqlite3vfs.Register("foo", vfs)
defer sqlite3vfs.Unregister("foo")
vfs.Register("foo", testVFS{t})
defer vfs.Unregister("foo")
conn, err := sqlite3.Open("file:file.db?vfs=foo")
if err != nil {

View File

@@ -16,7 +16,8 @@ import (
"sync/atomic"
"testing"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
"github.com/ncruces/go-sqlite3/vfs"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
@@ -34,13 +35,12 @@ var (
instances atomic.Uint64
)
func init() {
ctx := context.TODO()
rt = wazero.NewRuntime(ctx)
func TestMain(m *testing.M) {
ctx := context.Background()
rt = wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfig().WithDebugInfoEnabled(false))
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := sqlite3vfs.ExportHostFunctions(rt.NewHostModuleBuilder("env"))
env := vfs.ExportHostFunctions(rt.NewHostModuleBuilder("env"))
env.NewFunctionBuilder().WithFunc(system).Export("system")
_, err := env.Instantiate(ctx)
if err != nil {
@@ -51,6 +51,8 @@ func init() {
if err != nil {
panic(err)
}
os.Exit(m.Run())
}
func config(ctx context.Context) wazero.ModuleConfig {
@@ -80,7 +82,7 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
cfg := config(ctx).WithArgs(args...)
go func() {
ctx, vfs := sqlite3vfs.NewContext(ctx)
ctx, vfs := vfs.NewContext(ctx)
mod, _ := rt.InstantiateModule(ctx, module, cfg)
mod.Close(ctx)
vfs.Close()
@@ -89,7 +91,7 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
}
func Test_config01(t *testing.T) {
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
ctx, vfs := vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -108,7 +110,7 @@ func Test_config02(t *testing.T) {
t.Skip("skipping in CI")
}
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
ctx, vfs := vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "config02.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -124,7 +126,7 @@ func Test_crash01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
ctx, vfs := vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "crash01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -140,7 +142,7 @@ func Test_multiwrite01(t *testing.T) {
t.Skip("skipping in short mode")
}
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
ctx, vfs := vfs.NewContext(newContext(t))
name := filepath.Join(t.TempDir(), "test.db")
cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test")
mod, err := rt.InstantiateModule(ctx, module, cfg)
@@ -151,6 +153,38 @@ func Test_multiwrite01(t *testing.T) {
vfs.Close()
}
func Test_config01_memory(t *testing.T) {
ctx, vfs := vfs.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "test.db",
"config01.test",
"--vfs", "memdb",
"--timeout", "1000")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01_memory(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ctx, vfs := vfs.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "/test.db",
"multiwrite01.test",
"--vfs", "memdb",
"--timeout", "1000")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func newContext(t *testing.T) context.Context {
return context.WithValue(context.Background(), logger{}, &testWriter{T: t})
}
@@ -158,7 +192,9 @@ func newContext(t *testing.T) context.Context {
type logger struct{}
type testWriter struct {
// +checklocks:mtx
*testing.T
// +checklocks:mtx
buf []byte
mtx sync.Mutex
}

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

3
vfs/tests/mptest/testdata/mptest.wasm vendored Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53bb204be0ddd312c23976e7b627ad62fb384d220557f4e3f723d89273d78a01
size 1477009

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/rand"
"flag"
"io"
"os"
"path/filepath"
@@ -17,7 +18,8 @@ import (
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
"github.com/ncruces/go-sqlite3/vfs"
_ "github.com/ncruces/go-sqlite3/vfs/memdb"
)
//go:embed testdata/speedtest1.wasm
@@ -30,12 +32,13 @@ var (
options []string
)
func init() {
ctx := context.TODO()
func TestMain(m *testing.M) {
initFlags()
rt = wazero.NewRuntime(ctx)
ctx := context.Background()
rt = wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfig().WithDebugInfoEnabled(false))
wasi_snapshot_preview1.MustInstantiate(ctx, rt)
env := sqlite3vfs.ExportHostFunctions(rt.NewHostModuleBuilder("env"))
env := vfs.ExportHostFunctions(rt.NewHostModuleBuilder("env"))
_, err := env.Instantiate(ctx)
if err != nil {
panic(err)
@@ -45,29 +48,33 @@ func init() {
if err != nil {
panic(err)
}
code := m.Run()
defer os.Exit(code)
io.Copy(os.Stderr, &output)
}
func TestMain(m *testing.M) {
func initFlags() {
i := 1
options = append(options, "speedtest1")
for _, arg := range os.Args[1:] {
if strings.HasPrefix(arg, "-test.") {
switch {
case strings.HasPrefix(arg, "-test."):
// keep test flags
os.Args[i] = arg
i++
} else {
default:
// collect everything else
options = append(options, arg)
}
}
os.Args = os.Args[:i]
code := m.Run()
io.Copy(os.Stderr, &output)
os.Exit(code)
flag.Parse()
}
func Benchmark_speedtest1(b *testing.B) {
output.Reset()
ctx, vfs := sqlite3vfs.NewContext(context.Background())
ctx, vfs := vfs.NewContext(context.Background())
name := filepath.Join(b.TempDir(), "test.db")
args := append(options, "--size", strconv.Itoa(b.N), name)
cfg := wazero.NewModuleConfig().

View File

@@ -4,7 +4,7 @@ set -euo pipefail
cd -P -- "$(dirname -- "$0")"
ROOT=../../../../
BINARYEN="$ROOT/tools/binaryen-version_112/bin"
BINARYEN="$ROOT/tools/binaryen-version_113/bin"
WASI_SDK="$ROOT/tools/wasi-sdk-20.0/bin"
"$WASI_SDK/clang" --target=wasm32-wasi -flto -g0 -O2 \

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:23ac5877bc34cf5ff06ed2cf1c5ac148e5d364d6d01c85bd0f99d14f0a681830
size 1513516

View File

@@ -1,9 +1,10 @@
package sqlite3vfs
package vfs
import (
"context"
"crypto/rand"
"io"
"net/url"
"reflect"
"time"
@@ -13,10 +14,10 @@ import (
"github.com/tetratelabs/wazero/api"
)
// ExportHostFunctions is an internal API users need not call directly.
//
// ExportHostFunctions registers the required VFS host functions
// with the provided env module.
//
// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly.
func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_vfs_find", vfsFind)
util.ExportFuncIIJ(env, "go_localtime", vfsLocaltime)
@@ -48,17 +49,15 @@ type vfsState struct {
files []File
}
// NewContext is an internal API users need not call directly.
//
// NewContext creates a new context to hold [api.Module] specific VFS data.
//
// This context should be passed to any [api.Function] calls that might
// The context should be passed to any [api.Function] calls that might
// generate VFS host callbacks.
//
// The returned [io.Closer] should be closed after the [api.Module] is closed,
// to release any associated resources.
//
// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly.
func NewContext(ctx context.Context) (context.Context, io.Closer) {
vfs := &vfsState{}
vfs := new(vfsState)
return context.WithValue(ctx, vfsKey{}, vfs), vfs
}
@@ -74,7 +73,7 @@ func (vfs *vfsState) Close() error {
func vfsFind(ctx context.Context, mod api.Module, zVfsName uint32) uint32 {
name := util.ReadString(mod, zVfsName, _MAX_STRING)
if Find(name) != nil {
if vfs := Find(name); vfs != nil && vfs != (vfsOS{}) {
return 1
}
return 0
@@ -172,7 +171,27 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
path = util.ReadString(mod, zPath, _MAX_PATHNAME)
}
file, flags, err := vfs.Open(path, flags)
var file File
var err error
var parsed bool
var params url.Values
if pfs, ok := vfs.(VFSParams); ok {
parsed = true
params = vfsURIParameters(ctx, mod, zPath, flags)
file, flags, err = pfs.OpenParams(path, flags, params)
} else {
file, flags, err = vfs.Open(path, flags)
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if !parsed {
params = vfsURIParameters(ctx, mod, zPath, flags)
}
if b, ok := util.ParseBool(params.Get("psow")); ok {
file.SetPowersafeOverwrite(b)
}
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
@@ -187,7 +206,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode {
err := vfsFileClose(ctx, mod, pFile)
if err != nil {
return _IOERR_CLOSE
return vfsErrorCode(err, _IOERR_CLOSE)
}
return _OK
}
@@ -200,12 +219,10 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
if n == int(iAmt) {
return _OK
}
if n == 0 && err != io.EOF {
return _IOERR_READ
}
for i := range buf[n:] {
buf[n+i] = 0
if err != io.EOF {
return vfsErrorCode(err, _IOERR_READ)
}
clear(buf[n:])
return _IOERR_SHORT_READ
}
@@ -215,7 +232,7 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOf
_, err := file.WriteAt(buf, iOfst)
if err != nil {
return _IOERR_WRITE
return vfsErrorCode(err, _IOERR_WRITE)
}
return _OK
}
@@ -234,7 +251,7 @@ func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags SyncFlag)
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
size, err := file.FileSize()
size, err := file.Size()
util.WriteUint64(mod, pSize, uint64(size))
return vfsErrorCode(err, _IOERR_SEEK)
}
@@ -318,12 +335,35 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
util.WriteUint32(mod, pArg, res)
return vfsErrorCode(err, _IOERR_FSTAT)
}
case _FCNTL_COMMIT_PHASETWO:
if file, ok := file.(FileCommitPhaseTwo); ok {
err := file.CommitPhaseTwo()
return vfsErrorCode(err, _IOERR)
}
case _FCNTL_BEGIN_ATOMIC_WRITE:
if file, ok := file.(FileBatchAtomicWrite); ok {
err := file.BeginAtomicWrite()
return vfsErrorCode(err, _IOERR_BEGIN_ATOMIC)
}
case _FCNTL_COMMIT_ATOMIC_WRITE:
if file, ok := file.(FileBatchAtomicWrite); ok {
err := file.CommitAtomicWrite()
return vfsErrorCode(err, _IOERR_COMMIT_ATOMIC)
}
case _FCNTL_ROLLBACK_ATOMIC_WRITE:
if file, ok := file.(FileBatchAtomicWrite); ok {
err := file.RollbackAtomicWrite()
return vfsErrorCode(err, _IOERR_ROLLBACK_ATOMIC)
}
}
// Consider also implementing these opcodes (in use by SQLite):
// _FCNTL_BUSYHANDLER
// _FCNTL_COMMIT_PHASETWO
// _FCNTL_PDB
// _FCNTL_BUSYHANDLER
// _FCNTL_CHUNK_SIZE
// _FCNTL_OVERWRITE
// _FCNTL_PRAGMA
// _FCNTL_SYNC
return _NOTFOUND
@@ -339,14 +379,51 @@ func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32)
return file.DeviceCharacteristics()
}
func vfsGet(mod api.Module, pVfs uint32) VFS {
if pVfs == 0 {
return vfsOS{}
func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags OpenFlag) url.Values {
if flags&OPEN_URI == 0 {
return nil
}
const zNameOffset = 16
name := util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_STRING)
if name == "os" {
return vfsOS{}
uriParam := mod.ExportedFunction("sqlite3_uri_parameter")
uriKey := mod.ExportedFunction("sqlite3_uri_key")
if uriParam == nil || uriKey == nil {
return nil
}
var stack [2]uint64
var params url.Values
for i := 0; ; i++ {
stack[1] = uint64(i)
stack[0] = uint64(zPath)
if err := uriKey.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if stack[0] == 0 {
return params
}
key := util.ReadString(mod, uint32(stack[0]), _MAX_STRING)
if params.Has(key) {
continue
}
stack[1] = stack[0]
stack[0] = uint64(zPath)
if err := uriParam.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if params == nil {
params = url.Values{}
}
params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_STRING))
}
}
func vfsGet(mod api.Module, pVfs uint32) VFS {
var name string
if pVfs != 0 {
const zNameOffset = 16
name = util.ReadString(mod, util.ReadUint32(mod, pVfs+zNameOffset), _MAX_STRING)
}
if vfs := Find(name); vfs != nil {
return vfs
@@ -354,6 +431,42 @@ func vfsGet(mod api.Module, pVfs uint32) VFS {
panic(util.NoVFSErr + util.ErrorString(name))
}
func vfsFileNew(vfs *vfsState, file File) uint32 {
// Find an empty slot.
for id, f := range vfs.files {
if f == nil {
vfs.files[id] = file
return uint32(id)
}
}
// Add a new slot.
vfs.files = append(vfs.files, file)
return uint32(len(vfs.files) - 1)
}
func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file File) {
const fileHandleOffset = 4
id := vfsFileNew(ctx.Value(vfsKey{}).(*vfsState), file)
util.WriteUint32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
return vfs.files[id]
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {
const fileHandleOffset = 4
vfs := ctx.Value(vfsKey{}).(*vfsState)
id := util.ReadUint32(mod, pFile+fileHandleOffset)
file := vfs.files[id]
vfs.files[id] = nil
return file.Close()
}
func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
if err == nil {
return _OK
@@ -364,3 +477,9 @@ func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
}
return def
}
func clear(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -1,4 +1,4 @@
package sqlite3vfs
package vfs
import (
"bytes"
@@ -14,10 +14,11 @@ import (
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/julianday"
"github.com/tetratelabs/wazero/experimental/wazerotest"
)
func Test_vfsLocaltime(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx := context.TODO()
tm := time.Now()
@@ -53,7 +54,7 @@ func Test_vfsLocaltime(t *testing.T) {
}
func Test_vfsRandomness(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx := context.TODO()
rc := vfsRandomness(ctx, mod, 0, 16, 4)
@@ -68,7 +69,7 @@ func Test_vfsRandomness(t *testing.T) {
}
func Test_vfsSleep(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx := context.TODO()
now := time.Now()
@@ -84,7 +85,7 @@ func Test_vfsSleep(t *testing.T) {
}
func Test_vfsCurrentTime(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx := context.TODO()
now := time.Now()
@@ -100,7 +101,7 @@ func Test_vfsCurrentTime(t *testing.T) {
}
func Test_vfsCurrentTime64(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx := context.TODO()
now := time.Now()
@@ -118,7 +119,7 @@ func Test_vfsCurrentTime64(t *testing.T) {
}
func Test_vfsFullPathname(t *testing.T) {
mod := util.NewMockModule(128 + _MAX_PATHNAME)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
util.WriteString(mod, 4, ".")
ctx := context.TODO()
@@ -147,7 +148,7 @@ func Test_vfsDelete(t *testing.T) {
}
file.Close()
mod := util.NewMockModule(128 + _MAX_PATHNAME)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
util.WriteString(mod, 4, name)
ctx := context.TODO()
@@ -168,7 +169,7 @@ func Test_vfsDelete(t *testing.T) {
func Test_vfsAccess(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(t.TempDir(), "test.db")
file := filepath.Join(dir, "test.db")
if f, err := os.Create(file); err != nil {
t.Fatal(err)
} else {
@@ -178,7 +179,7 @@ func Test_vfsAccess(t *testing.T) {
t.Fatal(err)
}
mod := util.NewMockModule(128 + _MAX_PATHNAME)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
util.WriteString(mod, 8, dir)
ctx := context.TODO()
@@ -198,6 +199,15 @@ func Test_vfsAccess(t *testing.T) {
t.Error("can't access directory")
}
util.WriteString(mod, 8, file)
rc = vfsAccess(ctx, mod, 0, 8, ACCESS_READ, 4)
if rc != _OK {
t.Fatal("returned", rc)
}
if got := util.ReadUint32(mod, 4); got != 1 {
t.Error("can't access file")
}
util.WriteString(mod, 8, file)
rc = vfsAccess(ctx, mod, 0, 8, ACCESS_READWRITE, 4)
if rc != _OK {
@@ -209,7 +219,7 @@ func Test_vfsAccess(t *testing.T) {
}
func Test_vfsFile(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()
@@ -282,7 +292,7 @@ func Test_vfsFile(t *testing.T) {
}
func Test_vfsFile_psow(t *testing.T) {
mod := util.NewMockModule(128)
mod := wazerotest.NewModule(wazerotest.NewMemory(wazerotest.PageSize))
ctx, vfs := NewContext(context.TODO())
defer vfs.Close()