diff --git a/blob.go b/blob.go index 5a0ee61..d0d9c3e 100644 --- a/blob.go +++ b/blob.go @@ -1,6 +1,10 @@ package sqlite3 -import "io" +import ( + "io" + + "github.com/ncruces/go-sqlite3/internal/util" +) // ZeroBlob represents a zero-filled, length n BLOB // that can be used as an argument to @@ -46,7 +50,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, } blob := Blob{c: c} - blob.handle = c.mem.readUint32(blobPtr) + blob.handle = util.ReadUint32(c.mod, blobPtr) blob.bytes = int64(c.call(c.api.blobBytes, uint64(blob.handle))[0]) return &blob, nil } @@ -98,7 +102,7 @@ func (b *Blob) Read(p []byte) (n int, err error) { return 0, err } - mem := b.c.mem.view(ptr, uint64(want)) + mem := util.View(b.c.mod, ptr, uint64(want)) copy(p, mem) b.offset += want if b.offset >= b.bytes { @@ -133,7 +137,7 @@ func (b *Blob) Write(p []byte) (n int, err error) { func (b *Blob) Seek(offset int64, whence int) (int64, error) { switch whence { default: - return 0, whenceErr + return 0, util.WhenceErr case io.SeekStart: break case io.SeekCurrent: @@ -142,7 +146,7 @@ func (b *Blob) Seek(offset int64, whence int) (int64, error) { offset += b.bytes } if offset < 0 { - return 0, offsetErr + return 0, util.OffsetErr } b.offset = offset return offset, nil diff --git a/conn.go b/conn.go index b1e11c5..a9e575f 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,8 @@ import ( "strings" "sync/atomic" "unsafe" + + "github.com/ncruces/go-sqlite3/internal/util" ) // Conn is a database connection handle. @@ -55,7 +57,7 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { if conn == nil { mod.close() } else { - runtime.SetFinalizer(conn, finalizer[Conn](3)) + runtime.SetFinalizer(conn, util.Finalizer[Conn](3)) } }() @@ -76,7 +78,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) { flags |= OPEN_EXRESCODE r := c.call(c.api.open, uint64(namePtr), uint64(connPtr), uint64(flags), 0) - handle := c.mem.readUint32(connPtr) + handle := util.ReadUint32(c.mod, connPtr) if err := c.module.error(r[0], handle); err != nil { c.closeDB(handle) return 0, err @@ -182,8 +184,8 @@ func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail str uint64(stmtPtr), uint64(tailPtr)) stmt = &Stmt{c: c} - stmt.handle = c.mem.readUint32(stmtPtr) - i := c.mem.readUint32(tailPtr) + stmt.handle = util.ReadUint32(c.mod, stmtPtr) + i := util.ReadUint32(c.mod, tailPtr) tail = sql[i-sqlPtr:] if err := c.error(r[0], sql); err != nil { @@ -275,7 +277,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { break case <-ctx.Done(): // Done was closed. - buf := c.mem.view(c.handle+c.api.interrupt, 4) + buf := util.View(c.mod, c.handle+c.api.interrupt, 4) (*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1) // Wait for the next call to SetInterrupt. <-waiter @@ -291,7 +293,7 @@ func (c *Conn) checkInterrupt() bool { if c.interrupt == nil || c.interrupt.Err() == nil { return false } - buf := c.mem.view(c.handle+c.api.interrupt, 4) + buf := util.View(c.mod, c.handle+c.api.interrupt, 4) (*atomic.Uint32)(unsafe.Pointer(&buf[0])).Store(1) return true } diff --git a/const.go b/const.go index 928305d..7739b8d 100644 --- a/const.go +++ b/const.go @@ -7,8 +7,6 @@ const ( _ROW = 100 /* sqlite3_step() has another row ready */ _DONE = 101 /* sqlite3_step() has finished executing */ - _OK_SYMLINK = (_OK | (2 << 8)) /* internal use only */ - _UTF8 = 1 _MAX_STRING = 512 // Used for short strings: names, error messages… @@ -211,65 +209,3 @@ func (t Datatype) String() string { } return strconv.FormatUint(uint64(t), 10) } - -type _AccessFlag uint32 - -const ( - _ACCESS_EXISTS _AccessFlag = 0 - _ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */ - _ACCESS_READ _AccessFlag = 2 /* Unused */ -) - -type _SyncFlag uint32 - -const ( - _SYNC_NORMAL _SyncFlag = 0x00002 - _SYNC_FULL _SyncFlag = 0x00003 - _SYNC_DATAONLY _SyncFlag = 0x00010 -) - -type _FcntlOpcode uint32 - -const ( - _FCNTL_LOCKSTATE = 1 - _FCNTL_GET_LOCKPROXYFILE = 2 - _FCNTL_SET_LOCKPROXYFILE = 3 - _FCNTL_LAST_ERRNO = 4 - _FCNTL_SIZE_HINT = 5 - _FCNTL_CHUNK_SIZE = 6 - _FCNTL_FILE_POINTER = 7 - _FCNTL_SYNC_OMITTED = 8 - _FCNTL_WIN32_AV_RETRY = 9 - _FCNTL_PERSIST_WAL = 10 - _FCNTL_OVERWRITE = 11 - _FCNTL_VFSNAME = 12 - _FCNTL_POWERSAFE_OVERWRITE = 13 - _FCNTL_PRAGMA = 14 - _FCNTL_BUSYHANDLER = 15 - _FCNTL_TEMPFILENAME = 16 - _FCNTL_MMAP_SIZE = 18 - _FCNTL_TRACE = 19 - _FCNTL_HAS_MOVED = 20 - _FCNTL_SYNC = 21 - _FCNTL_COMMIT_PHASETWO = 22 - _FCNTL_WIN32_SET_HANDLE = 23 - _FCNTL_WAL_BLOCK = 24 - _FCNTL_ZIPVFS = 25 - _FCNTL_RBU = 26 - _FCNTL_VFS_POINTER = 27 - _FCNTL_JOURNAL_POINTER = 28 - _FCNTL_WIN32_GET_HANDLE = 29 - _FCNTL_PDB = 30 - _FCNTL_BEGIN_ATOMIC_WRITE = 31 - _FCNTL_COMMIT_ATOMIC_WRITE = 32 - _FCNTL_ROLLBACK_ATOMIC_WRITE = 33 - _FCNTL_LOCK_TIMEOUT = 34 - _FCNTL_DATA_VERSION = 35 - _FCNTL_SIZE_LIMIT = 36 - _FCNTL_CKPT_DONE = 37 - _FCNTL_RESERVE_BYTES = 38 - _FCNTL_CKPT_START = 39 - _FCNTL_EXTERNAL_READER = 40 - _FCNTL_CKSM_FILE = 41 - _FCNTL_RESET_CACHE = 42 -) diff --git a/driver/driver.go b/driver/driver.go index 24e28c1..8d0305b 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -35,6 +35,7 @@ import ( "time" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" ) func init() { @@ -134,7 +135,7 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro switch opts.Isolation { default: - return nil, isolationErr + return nil, util.IsolationErr case driver.IsolationLevel(sql.LevelDefault), driver.IsolationLevel(sql.LevelSerializable): @@ -183,7 +184,7 @@ func (c conn) Prepare(query string) (driver.Stmt, error) { if st != nil { s.Close() st.Close() - return nil, tailErr + return nil, util.TailErr } } return stmt{s, c.conn}, nil @@ -316,7 +317,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive case nil: err = s.stmt.BindNull(id) default: - panic(assertErr) + panic(util.AssertErr()) } } if err != nil { @@ -394,7 +395,7 @@ func (r rows) Next(dest []driver.Value) error { dest[i] = nil } default: - panic(assertErr) + panic(util.AssertErr()) } } diff --git a/driver/driver_test.go b/driver/driver_test.go index 0b3c5ba..2656309 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" ) func Test_Open_dir(t *testing.T) { @@ -142,7 +143,7 @@ func Test_BeginTx(t *testing.T) { defer db.Close() _, err = db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) - if err.Error() != string(isolationErr) { + if err.Error() != string(util.IsolationErr) { t.Error("want isolationErr") } @@ -230,7 +231,7 @@ func Test_Prepare(t *testing.T) { } _, err = db.Prepare(`SELECT 1; SELECT 2`) - if err.Error() != string(tailErr) { + if err.Error() != string(util.TailErr) { t.Error("want tailErr") } } diff --git a/driver/error.go b/driver/error.go deleted file mode 100644 index 8af54b1..0000000 --- a/driver/error.go +++ /dev/null @@ -1,11 +0,0 @@ -package driver - -type errorString string - -func (e errorString) Error() string { return string(e) } - -const ( - assertErr = errorString("sqlite3: assertion failed") - tailErr = errorString("sqlite3: multiple statements") - isolationErr = errorString("sqlite3: unsupported isolation level") -) diff --git a/error.go b/error.go index 0fa5bb5..4d00a6e 100644 --- a/error.go +++ b/error.go @@ -1,8 +1,6 @@ package sqlite3 import ( - "fmt" - "runtime" "strconv" "strings" ) @@ -188,36 +186,3 @@ func (e ExtendedErrorCode) Temporary() bool { func (e ExtendedErrorCode) Timeout() bool { return e == BUSY_TIMEOUT } - -type errorString string - -func (e errorString) Error() string { return string(e) } - -const ( - nilErr = errorString("sqlite3: invalid memory address or null pointer dereference") - oomErr = errorString("sqlite3: out of memory") - rangeErr = errorString("sqlite3: index out of range") - noNulErr = errorString("sqlite3: missing NUL terminator") - noGlobalErr = errorString("sqlite3: could not find global: ") - noFuncErr = errorString("sqlite3: could not find function: ") - binaryErr = errorString("sqlite3: no SQLite binary embed/set/loaded") - timeErr = errorString("sqlite3: invalid time value") - whenceErr = errorString("sqlite3: invalid whence") - offsetErr = errorString("sqlite3: invalid offset") -) - -func assertErr() errorString { - msg := "sqlite3: assertion failed" - if _, file, line, ok := runtime.Caller(1); ok { - msg += " (" + file + ":" + strconv.Itoa(line) + ")" - } - return errorString(msg) -} - -func finalizer[T any](skip int) func(*T) { - msg := fmt.Sprintf("sqlite3: %T not closed", new(T)) - if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 { - msg += " (" + file + ":" + strconv.Itoa(line) + ")" - } - return func(*T) { panic(errorString(msg)) } -} diff --git a/error_test.go b/error_test.go index 927b5ad..b56fb38 100644 --- a/error_test.go +++ b/error_test.go @@ -4,11 +4,13 @@ import ( "errors" "strings" "testing" + + "github.com/ncruces/go-sqlite3/internal/util" ) func Test_assertErr(t *testing.T) { - err := assertErr() - if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:10)") { + err := util.AssertErr() + if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") { t.Errorf("got %q", s) } } @@ -120,7 +122,7 @@ func Test_ErrorCode_Error(t *testing.T) { for i := 0; i == int(ErrorCode(i)); i++ { want := "sqlite3: " r := db.call(db.api.errstr, uint64(i)) - want += db.mem.readString(uint32(r[0]), _MAX_STRING) + want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING) got := ErrorCode(i).Error() if got != want { @@ -142,7 +144,7 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { for i := 0; i == int(ExtendedErrorCode(i)); i++ { want := "sqlite3: " r := db.call(db.api.errstr, uint64(i)) - want += db.mem.readString(uint32(r[0]), _MAX_STRING) + want += util.ReadString(db.mod, uint32(r[0]), _MAX_STRING) got := ExtendedErrorCode(i).Error() if got != want { diff --git a/go.mod b/go.mod index 03361db..e11cf69 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/ncruces/julianday v0.1.5 - github.com/tetratelabs/wazero v1.0.0 + github.com/tetratelabs/wazero v1.0.1 golang.org/x/sync v0.1.0 golang.org/x/sys v0.6.0 ) diff --git a/go.sum b/go.sum index 8200a38..bb54896 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk= github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= -github.com/tetratelabs/wazero v1.0.0 h1:sCE9+mjFex95Ki6hdqwvhyF25x5WslADjDKIFU5BXzI= -github.com/tetratelabs/wazero v1.0.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= +github.com/tetratelabs/wazero v1.0.1 h1:xyWBoGyMjYekG3mEQ/W7xm9E05S89kJ/at696d/9yuc= +github.com/tetratelabs/wazero v1.0.1/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= diff --git a/internal/util/error.go b/internal/util/error.go new file mode 100644 index 0000000..351aa3f --- /dev/null +++ b/internal/util/error.go @@ -0,0 +1,42 @@ +package util + +import ( + "fmt" + "runtime" + "strconv" +) + +type ErrorString string + +func (e ErrorString) Error() string { return string(e) } + +const ( + NilErr = ErrorString("sqlite3: invalid memory address or null pointer dereference") + OOMErr = ErrorString("sqlite3: out of memory") + RangeErr = ErrorString("sqlite3: index out of range") + NoNulErr = ErrorString("sqlite3: missing NUL terminator") + NoGlobalErr = ErrorString("sqlite3: could not find global: ") + NoFuncErr = ErrorString("sqlite3: could not find function: ") + BinaryErr = ErrorString("sqlite3: no SQLite binary embed/set/loaded") + TimeErr = ErrorString("sqlite3: invalid time value") + WhenceErr = ErrorString("sqlite3: invalid whence") + OffsetErr = ErrorString("sqlite3: invalid offset") + TailErr = ErrorString("sqlite3: multiple statements") + IsolationErr = ErrorString("sqlite3: unsupported isolation level") +) + +func AssertErr() ErrorString { + msg := "sqlite3: assertion failed" + if _, file, line, ok := runtime.Caller(1); ok { + msg += " (" + file + ":" + strconv.Itoa(line) + ")" + } + return ErrorString(msg) +} + +func Finalizer[T any](skip int) func(*T) { + msg := fmt.Sprintf("sqlite3: %T not closed", new(T)) + if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 { + msg += " (" + file + ":" + strconv.Itoa(line) + ")" + } + return func(*T) { panic(ErrorString(msg)) } +} diff --git a/internal/util/mem.go b/internal/util/mem.go new file mode 100644 index 0000000..54a7264 --- /dev/null +++ b/internal/util/mem.go @@ -0,0 +1,143 @@ +package util + +import ( + "bytes" + "math" + + "github.com/tetratelabs/wazero/api" +) + +func View(mod api.Module, ptr uint32, size uint64) []byte { + if ptr == 0 { + panic(NilErr) + } + if size > math.MaxUint32 { + panic(RangeErr) + } + buf, ok := mod.Memory().Read(ptr, uint32(size)) + if !ok { + panic(RangeErr) + } + return buf +} + +func ReadUint8(mod api.Module, ptr uint32) uint8 { + if ptr == 0 { + panic(NilErr) + } + v, ok := mod.Memory().ReadByte(ptr) + if !ok { + panic(RangeErr) + } + return v +} + +func WriteUint8(mod api.Module, ptr uint32, v uint8) { + if ptr == 0 { + panic(NilErr) + } + ok := mod.Memory().WriteByte(ptr, v) + if !ok { + panic(RangeErr) + } +} + +func ReadUint32(mod api.Module, ptr uint32) uint32 { + if ptr == 0 { + panic(NilErr) + } + v, ok := mod.Memory().ReadUint32Le(ptr) + if !ok { + panic(RangeErr) + } + return v +} + +func WriteUint32(mod api.Module, ptr uint32, v uint32) { + if ptr == 0 { + panic(NilErr) + } + ok := mod.Memory().WriteUint32Le(ptr, v) + if !ok { + panic(RangeErr) + } +} + +func ReadUint64(mod api.Module, ptr uint32) uint64 { + if ptr == 0 { + panic(NilErr) + } + v, ok := mod.Memory().ReadUint64Le(ptr) + if !ok { + panic(RangeErr) + } + return v +} + +func WriteUint64(mod api.Module, ptr uint32, v uint64) { + if ptr == 0 { + panic(NilErr) + } + ok := mod.Memory().WriteUint64Le(ptr, v) + if !ok { + panic(RangeErr) + } +} + +func ReadBool8(mod api.Module, ptr uint32) bool { + return ReadUint8(mod, ptr) != 0 +} + +func WriteBool8(mod api.Module, ptr uint32, v bool) { + var b uint8 + if v { + b = 1 + } + WriteUint8(mod, ptr, b) +} + +func ReadFloat64(mod api.Module, ptr uint32) float64 { + return math.Float64frombits(ReadUint64(mod, ptr)) +} + +func WriteFloat64(mod api.Module, ptr uint32, v float64) { + WriteUint64(mod, ptr, math.Float64bits(v)) +} + +func ReadString(mod api.Module, ptr, maxlen uint32) string { + if ptr == 0 { + panic(NilErr) + } + switch maxlen { + case 0: + return "" + case math.MaxUint32: + // avoid overflow + default: + maxlen = maxlen + 1 + } + mem := mod.Memory() + buf, ok := mem.Read(ptr, maxlen) + if !ok { + buf, ok = mem.Read(ptr, mem.Size()-ptr) + if !ok { + panic(RangeErr) + } + } + if i := bytes.IndexByte(buf, 0); i < 0 { + panic(NoNulErr) + } else { + return string(buf[:i]) + } +} + +func WriteBytes(mod api.Module, ptr uint32, b []byte) { + buf := View(mod, ptr, uint64(len(b))) + copy(buf, b) +} + +func WriteString(mod api.Module, ptr uint32, s string) { + buf := View(mod, ptr, uint64(len(s)+1)) + buf[len(s)] = 0 + copy(buf, s) +} diff --git a/internal/util/mem_test.go b/internal/util/mem_test.go new file mode 100644 index 0000000..4b32f70 --- /dev/null +++ b/internal/util/mem_test.go @@ -0,0 +1,90 @@ +package util + +import ( + "math" + "testing" +) + +func TestView_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + View(mock, 0, 8) + t.Error("want panic") +} + +func TestView_range(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + View(mock, 126, 8) + t.Error("want panic") +} + +func TestView_overflow(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + View(mock, 1, math.MaxInt64) + t.Error("want panic") +} + +func TestReadUint32_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + ReadUint32(mock, 0) + t.Error("want panic") +} + +func TestReadUint32_range(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + ReadUint32(mock, 126) + t.Error("want panic") +} + +func TestReadUint64_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + ReadUint64(mock, 0) + t.Error("want panic") +} + +func TestReadUint64_range(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + ReadUint64(mock, 126) + t.Error("want panic") +} + +func TestWriteUint32_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + WriteUint32(mock, 0, 1) + t.Error("want panic") +} + +func TestWriteUint32_range(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + WriteUint32(mock, 126, 1) + t.Error("want panic") +} + +func TestWriteUint64_nil(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + WriteUint64(mock, 0, 1) + t.Error("want panic") +} + +func TestWriteUint64_range(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + WriteUint64(mock, 126, 1) + t.Error("want panic") +} + +func TestReadString_range(t *testing.T) { + defer func() { _ = recover() }() + mock := NewMockModule(128) + ReadString(mock, 130, math.MaxUint32) + t.Error("want panic") +} diff --git a/mock_test.go b/internal/util/mock.go similarity index 96% rename from mock_test.go rename to internal/util/mock.go index fcb5ccc..ed1604b 100644 --- a/mock_test.go +++ b/internal/util/mock.go @@ -1,4 +1,4 @@ -package sqlite3 +package util import ( "context" @@ -8,13 +8,9 @@ import ( "github.com/tetratelabs/wazero/api" ) -func init() { - Path = "./embed/sqlite3.wasm" -} - -func newMemory(size uint32) memory { +func NewMockModule(size uint32) api.Module { mem := make(mockMemory, size) - return memory{mockModule{&mem}} + return mockModule{&mem} } type mockModule struct { diff --git a/internal/vfs/const.go b/internal/vfs/const.go new file mode 100644 index 0000000..c8c480a --- /dev/null +++ b/internal/vfs/const.go @@ -0,0 +1,148 @@ +package vfs + +const ( + _MAX_PATHNAME = 512 + + ptrlen = 4 +) + +type _ErrorCode uint32 + +const ( + _OK _ErrorCode = 0 /* Successful result */ + _PERM _ErrorCode = 3 /* Access permission denied */ + _BUSY _ErrorCode = 5 /* The database file is locked */ + _IOERR _ErrorCode = 10 /* Some kind of disk I/O error occurred */ + _NOTFOUND _ErrorCode = 12 /* Unknown opcode in sqlite3_file_control() */ + _CANTOPEN _ErrorCode = 14 /* Unable to open the database file */ + + _IOERR_READ = _IOERR | (1 << 8) + _IOERR_SHORT_READ = _IOERR | (2 << 8) + _IOERR_WRITE = _IOERR | (3 << 8) + _IOERR_FSYNC = _IOERR | (4 << 8) + _IOERR_DIR_FSYNC = _IOERR | (5 << 8) + _IOERR_TRUNCATE = _IOERR | (6 << 8) + _IOERR_FSTAT = _IOERR | (7 << 8) + _IOERR_UNLOCK = _IOERR | (8 << 8) + _IOERR_RDLOCK = _IOERR | (9 << 8) + _IOERR_DELETE = _IOERR | (10 << 8) + _IOERR_BLOCKED = _IOERR | (11 << 8) + _IOERR_NOMEM = _IOERR | (12 << 8) + _IOERR_ACCESS = _IOERR | (13 << 8) + _IOERR_CHECKRESERVEDLOCK = _IOERR | (14 << 8) + _IOERR_LOCK = _IOERR | (15 << 8) + _IOERR_CLOSE = _IOERR | (16 << 8) + _IOERR_DIR_CLOSE = _IOERR | (17 << 8) + _IOERR_SHMOPEN = _IOERR | (18 << 8) + _IOERR_SHMSIZE = _IOERR | (19 << 8) + _IOERR_SHMLOCK = _IOERR | (20 << 8) + _IOERR_SHMMAP = _IOERR | (21 << 8) + _IOERR_SEEK = _IOERR | (22 << 8) + _IOERR_DELETE_NOENT = _IOERR | (23 << 8) + _IOERR_MMAP = _IOERR | (24 << 8) + _IOERR_GETTEMPPATH = _IOERR | (25 << 8) + _IOERR_CONVPATH = _IOERR | (26 << 8) + _IOERR_VNODE = _IOERR | (27 << 8) + _IOERR_AUTH = _IOERR | (28 << 8) + _IOERR_BEGIN_ATOMIC = _IOERR | (29 << 8) + _IOERR_COMMIT_ATOMIC = _IOERR | (30 << 8) + _IOERR_ROLLBACK_ATOMIC = _IOERR | (31 << 8) + _IOERR_DATA = _IOERR | (32 << 8) + _IOERR_CORRUPTFS = _IOERR | (33 << 8) + _CANTOPEN_NOTEMPDIR = _CANTOPEN | (1 << 8) + _CANTOPEN_ISDIR = _CANTOPEN | (2 << 8) + _CANTOPEN_FULLPATH = _CANTOPEN | (3 << 8) + _CANTOPEN_CONVPATH = _CANTOPEN | (4 << 8) + _CANTOPEN_DIRTYWAL = _CANTOPEN | (5 << 8) /* Not Used */ + _CANTOPEN_SYMLINK = _CANTOPEN | (6 << 8) + _OK_SYMLINK = (_OK | (2 << 8)) /* internal use only */ +) + +type _OpenFlag uint32 + +const ( + _OPEN_READONLY _OpenFlag = 0x00000001 /* Ok for sqlite3_open_v2() */ + _OPEN_READWRITE _OpenFlag = 0x00000002 /* Ok for sqlite3_open_v2() */ + _OPEN_CREATE _OpenFlag = 0x00000004 /* Ok for sqlite3_open_v2() */ + _OPEN_DELETEONCLOSE _OpenFlag = 0x00000008 /* VFS only */ + _OPEN_EXCLUSIVE _OpenFlag = 0x00000010 /* VFS only */ + _OPEN_AUTOPROXY _OpenFlag = 0x00000020 /* VFS only */ + _OPEN_URI _OpenFlag = 0x00000040 /* Ok for sqlite3_open_v2() */ + _OPEN_MEMORY _OpenFlag = 0x00000080 /* Ok for sqlite3_open_v2() */ + _OPEN_MAIN_DB _OpenFlag = 0x00000100 /* VFS only */ + _OPEN_TEMP_DB _OpenFlag = 0x00000200 /* VFS only */ + _OPEN_TRANSIENT_DB _OpenFlag = 0x00000400 /* VFS only */ + _OPEN_MAIN_JOURNAL _OpenFlag = 0x00000800 /* VFS only */ + _OPEN_TEMP_JOURNAL _OpenFlag = 0x00001000 /* VFS only */ + _OPEN_SUBJOURNAL _OpenFlag = 0x00002000 /* VFS only */ + _OPEN_SUPER_JOURNAL _OpenFlag = 0x00004000 /* VFS only */ + _OPEN_NOMUTEX _OpenFlag = 0x00008000 /* Ok for sqlite3_open_v2() */ + _OPEN_FULLMUTEX _OpenFlag = 0x00010000 /* Ok for sqlite3_open_v2() */ + _OPEN_SHAREDCACHE _OpenFlag = 0x00020000 /* Ok for sqlite3_open_v2() */ + _OPEN_PRIVATECACHE _OpenFlag = 0x00040000 /* Ok for sqlite3_open_v2() */ + _OPEN_WAL _OpenFlag = 0x00080000 /* VFS only */ + _OPEN_NOFOLLOW _OpenFlag = 0x01000000 /* Ok for sqlite3_open_v2() */ + _OPEN_EXRESCODE _OpenFlag = 0x02000000 /* Extended result codes */ +) + +type _AccessFlag uint32 + +const ( + _ACCESS_EXISTS _AccessFlag = 0 + _ACCESS_READWRITE _AccessFlag = 1 /* Used by PRAGMA temp_store_directory */ + _ACCESS_READ _AccessFlag = 2 /* Unused */ +) + +type _SyncFlag uint32 + +const ( + _SYNC_NORMAL _SyncFlag = 0x00002 + _SYNC_FULL _SyncFlag = 0x00003 + _SYNC_DATAONLY _SyncFlag = 0x00010 +) + +type _FcntlOpcode uint32 + +const ( + _FCNTL_LOCKSTATE = 1 + _FCNTL_GET_LOCKPROXYFILE = 2 + _FCNTL_SET_LOCKPROXYFILE = 3 + _FCNTL_LAST_ERRNO = 4 + _FCNTL_SIZE_HINT = 5 + _FCNTL_CHUNK_SIZE = 6 + _FCNTL_FILE_POINTER = 7 + _FCNTL_SYNC_OMITTED = 8 + _FCNTL_WIN32_AV_RETRY = 9 + _FCNTL_PERSIST_WAL = 10 + _FCNTL_OVERWRITE = 11 + _FCNTL_VFSNAME = 12 + _FCNTL_POWERSAFE_OVERWRITE = 13 + _FCNTL_PRAGMA = 14 + _FCNTL_BUSYHANDLER = 15 + _FCNTL_TEMPFILENAME = 16 + _FCNTL_MMAP_SIZE = 18 + _FCNTL_TRACE = 19 + _FCNTL_HAS_MOVED = 20 + _FCNTL_SYNC = 21 + _FCNTL_COMMIT_PHASETWO = 22 + _FCNTL_WIN32_SET_HANDLE = 23 + _FCNTL_WAL_BLOCK = 24 + _FCNTL_ZIPVFS = 25 + _FCNTL_RBU = 26 + _FCNTL_VFS_POINTER = 27 + _FCNTL_JOURNAL_POINTER = 28 + _FCNTL_WIN32_GET_HANDLE = 29 + _FCNTL_PDB = 30 + _FCNTL_BEGIN_ATOMIC_WRITE = 31 + _FCNTL_COMMIT_ATOMIC_WRITE = 32 + _FCNTL_ROLLBACK_ATOMIC_WRITE = 33 + _FCNTL_LOCK_TIMEOUT = 34 + _FCNTL_DATA_VERSION = 35 + _FCNTL_SIZE_LIMIT = 36 + _FCNTL_CKPT_DONE = 37 + _FCNTL_RESERVE_BYTES = 38 + _FCNTL_CKPT_START = 39 + _FCNTL_EXTERNAL_READER = 40 + _FCNTL_CKSM_FILE = 41 + _FCNTL_RESET_CACHE = 42 +) diff --git a/internal/vfs/func.go b/internal/vfs/func.go new file mode 100644 index 0000000..cd56a90 --- /dev/null +++ b/internal/vfs/func.go @@ -0,0 +1,78 @@ +package vfs + +import ( + "context" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +func registerFunc1[T0, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0) TR) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, T0(stack[0]))) + }), + []api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + +func registerFunc2[T0, T1, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) TR) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]))) + }), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + +func registerFunc3[T0, T1, T2, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2) TR) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))) + }), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + +func registerFunc4[T0, T1, T2, T3, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) TR) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))) + }), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + +func registerFunc5[T0, T1, T2, T3, T4, TR ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3, _ T4) TR) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))) + }), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + +func registerFuncRW(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _, _, _ uint32, _ int64) _ErrorCode) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, uint32(stack[0]), uint32(stack[1]), uint32(stack[2]), int64(stack[3]))) + }), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} + +func registerFuncT(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ uint32, _ int64) _ErrorCode) { + mod.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc( + func(ctx context.Context, mod api.Module, stack []uint64) { + stack[0] = uint64(fn(ctx, mod, uint32(stack[0]), int64(stack[1]))) + }), + []api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}). + Export(name) +} diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go new file mode 100644 index 0000000..5223231 --- /dev/null +++ b/internal/vfs/vfs.go @@ -0,0 +1,359 @@ +package vfs + +import ( + "context" + "crypto/rand" + "errors" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "time" + + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/julianday" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +func Instantiate(ctx context.Context, r wazero.Runtime) { + env := NewEnvModuleBuilder(r) + _, err := env.Instantiate(ctx) + if err != nil { + panic(err) + } +} + +func NewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder { + env := r.NewHostModuleBuilder("env") + registerFuncT(env, "os_localtime", vfsLocaltime) + registerFunc3(env, "os_randomness", vfsRandomness) + registerFunc2(env, "os_sleep", vfsSleep) + registerFunc2(env, "os_current_time", vfsCurrentTime) + registerFunc2(env, "os_current_time_64", vfsCurrentTime64) + registerFunc4(env, "os_full_pathname", vfsFullPathname) + registerFunc3(env, "os_delete", vfsDelete) + registerFunc4(env, "os_access", vfsAccess) + registerFunc5(env, "os_open", vfsOpen) + registerFunc1(env, "os_close", vfsClose) + registerFuncRW(env, "os_read", vfsRead) + registerFuncRW(env, "os_write", vfsWrite) + registerFuncT(env, "os_truncate", vfsTruncate) + registerFunc2(env, "os_sync", vfsSync) + registerFunc2(env, "os_file_size", vfsFileSize) + registerFunc2(env, "os_lock", vfsLock) + registerFunc2(env, "os_unlock", vfsUnlock) + registerFunc2(env, "os_check_reserved_lock", vfsCheckReservedLock) + registerFunc3(env, "os_file_control", vfsFileControl) + return env +} + +type vfsKey struct{} +type vfsState struct { + files []*os.File +} + +func Context(ctx context.Context) (context.Context, io.Closer) { + vfs := &vfsState{} + return context.WithValue(ctx, vfsKey{}, vfs), vfs +} + +func (vfs *vfsState) Close() error { + for _, f := range vfs.files { + if f != nil { + f.Close() + } + } + vfs.files = nil + return nil +} + +func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) _ErrorCode { + tm := time.Unix(t, 0) + var isdst int + if tm.IsDST() { + isdst = 1 + } + + // https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html + util.WriteUint32(mod, pTm+0*ptrlen, uint32(tm.Second())) + util.WriteUint32(mod, pTm+1*ptrlen, uint32(tm.Minute())) + util.WriteUint32(mod, pTm+2*ptrlen, uint32(tm.Hour())) + util.WriteUint32(mod, pTm+3*ptrlen, uint32(tm.Day())) + util.WriteUint32(mod, pTm+4*ptrlen, uint32(tm.Month()-time.January)) + util.WriteUint32(mod, pTm+5*ptrlen, uint32(tm.Year()-1900)) + util.WriteUint32(mod, pTm+6*ptrlen, uint32(tm.Weekday()-time.Sunday)) + util.WriteUint32(mod, pTm+7*ptrlen, uint32(tm.YearDay()-1)) + util.WriteUint32(mod, pTm+8*ptrlen, uint32(isdst)) + return _OK +} + +func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 { + mem := util.View(mod, zByte, uint64(nByte)) + n, _ := rand.Reader.Read(mem) + return uint32(n) +} + +func vfsSleep(ctx context.Context, mod api.Module, pVfs, nMicro uint32) _ErrorCode { + time.Sleep(time.Duration(nMicro) * time.Microsecond) + return _OK +} + +func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) _ErrorCode { + day := julianday.Float(time.Now()) + util.WriteFloat64(mod, prNow, day) + return _OK +} + +func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) _ErrorCode { + day, nsec := julianday.Date(time.Now()) + msec := day*86_400_000 + nsec/1_000_000 + util.WriteUint64(mod, piNow, uint64(msec)) + return _OK +} + +func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) _ErrorCode { + rel := util.ReadString(mod, zRelative, _MAX_PATHNAME) + abs, err := filepath.Abs(rel) + if err != nil { + return _CANTOPEN_FULLPATH + } + + size := uint64(len(abs) + 1) + if size > uint64(nFull) { + return _CANTOPEN_FULLPATH + } + mem := util.View(mod, zFull, size) + mem[len(abs)] = 0 + copy(mem, abs) + + if fi, err := os.Lstat(abs); err == nil { + if fi.Mode()&fs.ModeSymlink != 0 { + return _OK_SYMLINK + } + return _OK + } else if errors.Is(err, fs.ErrNotExist) { + return _OK + } + return _CANTOPEN_FULLPATH +} + +func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) _ErrorCode { + path := util.ReadString(mod, zPath, _MAX_PATHNAME) + err := os.Remove(path) + if errors.Is(err, fs.ErrNotExist) { + return _IOERR_DELETE_NOENT + } + if err != nil { + return _IOERR_DELETE + } + if runtime.GOOS != "windows" && syncDir != 0 { + f, err := os.Open(filepath.Dir(path)) + if err != nil { + return _OK + } + defer f.Close() + err = osSync(f, false, false) + if err != nil { + return _IOERR_DIR_FSYNC + } + } + return _OK +} + +func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) _ErrorCode { + path := util.ReadString(mod, zPath, _MAX_PATHNAME) + err := osAccess(path, flags) + + var res uint32 + var rc _ErrorCode + if flags == _ACCESS_EXISTS { + switch { + case err == nil: + res = 1 + case errors.Is(err, fs.ErrNotExist): + res = 0 + default: + rc = _IOERR_ACCESS + } + } else { + switch { + case err == nil: + res = 1 + case errors.Is(err, fs.ErrPermission): + res = 0 + default: + rc = _IOERR_ACCESS + } + } + + util.WriteUint32(mod, pResOut, res) + return rc +} + +func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags _OpenFlag, pOutFlags uint32) _ErrorCode { + var oflags int + if flags&_OPEN_EXCLUSIVE != 0 { + oflags |= os.O_EXCL + } + if flags&_OPEN_CREATE != 0 { + oflags |= os.O_CREATE + } + if flags&_OPEN_READONLY != 0 { + oflags |= os.O_RDONLY + } + if flags&_OPEN_READWRITE != 0 { + oflags |= os.O_RDWR + } + + var err error + var file *os.File + if zName == 0 { + file, err = os.CreateTemp("", "*.db") + } else { + name := util.ReadString(mod, zName, _MAX_PATHNAME) + file, err = osOpenFile(name, oflags, 0666) + } + if err != nil { + return _CANTOPEN + } + + if flags&_OPEN_DELETEONCLOSE != 0 { + os.Remove(file.Name()) + } + + openFile(ctx, mod, pFile, file) + + if flags&_OPEN_READONLY != 0 { + setFileReadOnly(ctx, mod, pFile, true) + } + if runtime.GOOS != "windows" && + flags&(_OPEN_CREATE) != 0 && + flags&(_OPEN_MAIN_JOURNAL|_OPEN_SUPER_JOURNAL|_OPEN_WAL) != 0 { + setFileSyncDir(ctx, mod, pFile, true) + } + + if pOutFlags != 0 { + util.WriteUint32(mod, pOutFlags, uint32(flags)) + } + return _OK +} + +func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode { + err := closeFile(ctx, mod, pFile) + if err != nil { + return _IOERR_CLOSE + } + return _OK +} + +func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode { + buf := util.View(mod, zBuf, uint64(iAmt)) + + file := getOSFile(ctx, mod, pFile) + n, err := file.ReadAt(buf, iOfst) + if n == int(iAmt) { + return _OK + } + if n == 0 && err != io.EOF { + return _IOERR_READ + } + for i := range buf[n:] { + buf[n+i] = 0 + } + return _IOERR_SHORT_READ +} + +func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) _ErrorCode { + buf := util.View(mod, zBuf, uint64(iAmt)) + + file := getOSFile(ctx, mod, pFile) + _, err := file.WriteAt(buf, iOfst) + if err != nil { + return _IOERR_WRITE + } + return _OK +} + +func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode { + file := getOSFile(ctx, mod, pFile) + err := file.Truncate(nByte) + if err != nil { + return _IOERR_TRUNCATE + } + return _OK +} + +func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) _ErrorCode { + dataonly := (flags & _SYNC_DATAONLY) != 0 + fullsync := (flags & 0x0f) == _SYNC_FULL + + file := getOSFile(ctx, mod, pFile) + err := osSync(file, fullsync, dataonly) + if err != nil { + return _IOERR_FSYNC + } + if runtime.GOOS != "windows" && getFileSyncDir(ctx, mod, pFile) { + setFileSyncDir(ctx, mod, pFile, false) + f, err := os.Open(filepath.Dir(file.Name())) + if err != nil { + return _OK + } + defer f.Close() + err = osSync(f, false, false) + if err != nil { + return _IOERR_DIR_FSYNC + } + } + return _OK +} + +func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode { + file := getOSFile(ctx, mod, pFile) + off, err := file.Seek(0, io.SeekEnd) + if err != nil { + return _IOERR_SEEK + } + + util.WriteUint64(mod, pSize, uint64(off)) + return _OK +} + +func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode { + switch op { + case _FCNTL_SIZE_HINT: + return vfsSizeHint(ctx, mod, pFile, pArg) + case _FCNTL_HAS_MOVED: + return vfsFileMoved(ctx, mod, pFile, pArg) + } + return _NOTFOUND +} + +func vfsSizeHint(ctx context.Context, mod api.Module, pFile, pArg uint32) _ErrorCode { + file := getOSFile(ctx, mod, pFile) + size := util.ReadUint64(mod, pArg) + err := osAllocate(file, int64(size)) + if err != nil { + return _IOERR_TRUNCATE + } + return _OK +} + +func vfsFileMoved(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode { + file := getOSFile(ctx, mod, pFile) + fi, err := file.Stat() + if err != nil { + return _IOERR_FSTAT + } + pi, err := os.Stat(file.Name()) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return _IOERR_FSTAT + } + var res uint32 + if !os.SameFile(fi, pi) { + res = 1 + } + util.WriteUint32(mod, pResOut, res) + return _OK +} diff --git a/internal/vfs/vfs_file.go b/internal/vfs/vfs_file.go new file mode 100644 index 0000000..27ae82e --- /dev/null +++ b/internal/vfs/vfs_file.go @@ -0,0 +1,82 @@ +package vfs + +import ( + "context" + "os" + "time" + + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/tetratelabs/wazero/api" +) + +const ( + // These need to match the offsets asserted in os.c + vfsFileIDOffset = 4 + vfsFileLockOffset = 8 + vfsFileSyncDirOffset = 10 + vfsFileReadOnlyOffset = 11 + vfsFileLockTimeoutOffset = 12 +) + +func newFileID(ctx context.Context, file *os.File) uint32 { + vfs := ctx.Value(vfsKey{}).(*vfsState) + + // Find an empty slot. + for id, ptr := range vfs.files { + if ptr == nil { + vfs.files[id] = file + return uint32(id) + } + } + + // Add a new slot. + vfs.files = append(vfs.files, file) + return uint32(len(vfs.files) - 1) +} + +func openFile(ctx context.Context, mod api.Module, pFile uint32, file *os.File) { + id := newFileID(ctx, file) + util.WriteUint32(mod, pFile+vfsFileIDOffset, id) +} + +func closeFile(ctx context.Context, mod api.Module, pFile uint32) error { + id := util.ReadUint32(mod, pFile+vfsFileIDOffset) + vfs := ctx.Value(vfsKey{}).(*vfsState) + file := vfs.files[id] + vfs.files[id] = nil + return file.Close() +} + +func getOSFile(ctx context.Context, mod api.Module, pFile uint32) *os.File { + id := util.ReadUint32(mod, pFile+vfsFileIDOffset) + vfs := ctx.Value(vfsKey{}).(*vfsState) + return vfs.files[id] +} + +func getFileLock(ctx context.Context, mod api.Module, pFile uint32) vfsLockState { + return vfsLockState(util.ReadUint8(mod, pFile+vfsFileLockOffset)) +} + +func setFileLock(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) { + util.WriteUint8(mod, pFile+vfsFileLockOffset, uint8(lock)) +} + +func getFileLockTimeout(ctx context.Context, mod api.Module, pFile uint32) time.Duration { + return time.Duration(util.ReadUint32(mod, pFile+vfsFileLockTimeoutOffset)) * time.Millisecond +} + +func getFileSyncDir(ctx context.Context, mod api.Module, pFile uint32) bool { + return util.ReadBool8(mod, pFile+vfsFileSyncDirOffset) +} + +func setFileSyncDir(ctx context.Context, mod api.Module, pFile uint32, val bool) { + util.WriteBool8(mod, pFile+vfsFileSyncDirOffset, val) +} + +func getFileReadOnly(ctx context.Context, mod api.Module, pFile uint32) bool { + return util.ReadBool8(mod, pFile+vfsFileReadOnlyOffset) +} + +func setFileReadOnly(ctx context.Context, mod api.Module, pFile uint32, val bool) { + util.WriteBool8(mod, pFile+vfsFileReadOnlyOffset, val) +} diff --git a/vfs_lock.go b/internal/vfs/vfs_lock.go similarity index 67% rename from vfs_lock.go rename to internal/vfs/vfs_lock.go index 4ccd3a3..fddf23f 100644 --- a/vfs_lock.go +++ b/internal/vfs/vfs_lock.go @@ -1,10 +1,11 @@ -package sqlite3 +package vfs import ( "context" "os" "time" + "github.com/ncruces/go-sqlite3/internal/util" "github.com/tetratelabs/wazero/api" ) @@ -56,27 +57,27 @@ const ( type vfsLockState uint32 -func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 { +func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) _ErrorCode { // Argument check. SQLite never explicitly requests a pending lock. if eLock != _SHARED_LOCK && eLock != _RESERVED_LOCK && eLock != _EXCLUSIVE_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } - file := vfsFile.GetOS(ctx, mod, pFile) - cLock := vfsFile.GetLock(ctx, mod, pFile) - timeout := vfsFile.GetLockTimeout(ctx, mod, pFile) - readOnly := vfsFile.GetReadOnly(ctx, mod, pFile) + file := getOSFile(ctx, mod, pFile) + cLock := getFileLock(ctx, mod, pFile) + timeout := getFileLockTimeout(ctx, mod, pFile) + readOnly := getFileReadOnly(ctx, mod, pFile) switch { case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK: // Connection state check. - panic(assertErr()) + panic(util.AssertErr()) case cLock == _NO_LOCK && eLock > _SHARED_LOCK: // We never move from unlocked to anything higher than a shared lock. - panic(assertErr()) + panic(util.AssertErr()) case cLock != _SHARED_LOCK && eLock == _RESERVED_LOCK: // A shared lock is always held when a reserved lock is requested. - panic(assertErr()) + panic(util.AssertErr()) } // If we already have an equal or more restrictive lock, do nothing. @@ -86,67 +87,67 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta // Do not allow any kind of write-lock on a read-only database. if readOnly && eLock > _RESERVED_LOCK { - return uint32(IOERR_LOCK) + return _IOERR_LOCK } switch eLock { case _SHARED_LOCK: // Must be unlocked to get SHARED. if cLock != _NO_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } - if rc := vfsOS.GetSharedLock(file, timeout); rc != _OK { - return uint32(rc) + if rc := osGetSharedLock(file, timeout); rc != _OK { + return rc } - vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK) + setFileLock(ctx, mod, pFile, _SHARED_LOCK) return _OK case _RESERVED_LOCK: // Must be SHARED to get RESERVED. if cLock != _SHARED_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } - if rc := vfsOS.GetReservedLock(file, timeout); rc != _OK { - return uint32(rc) + if rc := osGetReservedLock(file, timeout); rc != _OK { + return rc } - vfsFile.SetLock(ctx, mod, pFile, _RESERVED_LOCK) + setFileLock(ctx, mod, pFile, _RESERVED_LOCK) return _OK case _EXCLUSIVE_LOCK: // Must be SHARED, RESERVED or PENDING to get EXCLUSIVE. if cLock <= _NO_LOCK || cLock >= _EXCLUSIVE_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } // A PENDING lock is needed before acquiring an EXCLUSIVE lock. if cLock < _PENDING_LOCK { - if rc := vfsOS.GetPendingLock(file); rc != _OK { - return uint32(rc) + if rc := osGetPendingLock(file); rc != _OK { + return rc } - vfsFile.SetLock(ctx, mod, pFile, _PENDING_LOCK) + setFileLock(ctx, mod, pFile, _PENDING_LOCK) } - if rc := vfsOS.GetExclusiveLock(file, timeout); rc != _OK { - return uint32(rc) + if rc := osGetExclusiveLock(file, timeout); rc != _OK { + return rc } - vfsFile.SetLock(ctx, mod, pFile, _EXCLUSIVE_LOCK) + setFileLock(ctx, mod, pFile, _EXCLUSIVE_LOCK) return _OK default: - panic(assertErr()) + panic(util.AssertErr()) } } -func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) uint32 { +func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockState) _ErrorCode { // Argument check. if eLock != _NO_LOCK && eLock != _SHARED_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } - file := vfsFile.GetOS(ctx, mod, pFile) - cLock := vfsFile.GetLock(ctx, mod, pFile) + file := getOSFile(ctx, mod, pFile) + cLock := getFileLock(ctx, mod, pFile) // Connection state check. if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } // If we don't have a more restrictive lock, do nothing. @@ -156,58 +157,58 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS switch eLock { case _SHARED_LOCK: - if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK { - return uint32(rc) + if rc := osDowngradeLock(file, cLock); rc != _OK { + return rc } - vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK) + setFileLock(ctx, mod, pFile, _SHARED_LOCK) return _OK case _NO_LOCK: - rc := vfsOS.ReleaseLock(file, cLock) - vfsFile.SetLock(ctx, mod, pFile, _NO_LOCK) - return uint32(rc) + rc := osReleaseLock(file, cLock) + setFileLock(ctx, mod, pFile, _NO_LOCK) + return rc default: - panic(assertErr()) + panic(util.AssertErr()) } } -func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 { - file := vfsFile.GetOS(ctx, mod, pFile) - cLock := vfsFile.GetLock(ctx, mod, pFile) +func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode { + file := getOSFile(ctx, mod, pFile) + cLock := getFileLock(ctx, mod, pFile) // Connection state check. if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK { - panic(assertErr()) + panic(util.AssertErr()) } var locked bool - var rc xErrorCode + var rc _ErrorCode if cLock >= _RESERVED_LOCK { locked = true } else { - locked, rc = vfsOS.CheckReservedLock(file) + locked, rc = osCheckReservedLock(file) } var res uint32 if locked { res = 1 } - memory{mod}.writeUint32(pResOut, res) - return uint32(rc) + util.WriteUint32(mod, pResOut, res) + return rc } -func (vfsOSMethods) GetReservedLock(file *os.File, timeout time.Duration) xErrorCode { +func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode { // Acquire the RESERVED lock. - return vfsOS.writeLock(file, _RESERVED_BYTE, 1, timeout) + return osWriteLock(file, _RESERVED_BYTE, 1, timeout) } -func (vfsOSMethods) GetPendingLock(file *os.File) xErrorCode { +func osGetPendingLock(file *os.File) _ErrorCode { // Acquire the PENDING lock. - return vfsOS.writeLock(file, _PENDING_BYTE, 1, 0) + return osWriteLock(file, _PENDING_BYTE, 1, 0) } -func (vfsOSMethods) CheckReservedLock(file *os.File) (bool, xErrorCode) { +func osCheckReservedLock(file *os.File) (bool, _ErrorCode) { // Test the RESERVED lock. - return vfsOS.checkLock(file, _RESERVED_BYTE, 1) + return osCheckLock(file, _RESERVED_BYTE, 1) } diff --git a/vfs_lock_test.go b/internal/vfs/vfs_lock_test.go similarity index 51% rename from vfs_lock_test.go rename to internal/vfs/vfs_lock_test.go index cd2aa99..8812b39 100644 --- a/vfs_lock_test.go +++ b/internal/vfs/vfs_lock_test.go @@ -1,4 +1,4 @@ -package sqlite3 +package vfs import ( "context" @@ -6,6 +6,8 @@ import ( "path/filepath" "runtime" "testing" + + "github.com/ncruces/go-sqlite3/internal/util" ) func Test_vfsLock(t *testing.T) { @@ -37,133 +39,133 @@ func Test_vfsLock(t *testing.T) { pFile2 = 16 pOutput = 32 ) - mem := newMemory(128) - ctx, vfs := vfsContext(context.TODO()) + mod := util.NewMockModule(128) + ctx, vfs := Context(context.TODO()) defer vfs.Close() - vfsFile.Open(ctx, mem.mod, pFile1, file1) - vfsFile.Open(ctx, mem.mod, pFile2, file2) + openFile(ctx, mod, pFile1, file1) + openFile(ctx, mod, pFile2, file2) - rc := vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) + rc := vfsCheckReservedLock(ctx, mod, pFile1, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got != 0 { + if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile2, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got != 0 { + if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } - rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK) + rc = vfsLock(ctx, mod, pFile2, _SHARED_LOCK) if rc != _OK { t.Fatal("returned", rc) } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got != 0 { + if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile2, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got != 0 { + if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } - rc = vfsLock(ctx, mem.mod, pFile2, _RESERVED_LOCK) + rc = vfsLock(ctx, mod, pFile2, _RESERVED_LOCK) if rc != _OK { t.Fatal("returned", rc) } - rc = vfsLock(ctx, mem.mod, pFile2, _SHARED_LOCK) + rc = vfsLock(ctx, mod, pFile2, _SHARED_LOCK) if rc != _OK { t.Fatal("returned", rc) } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got == 0 { + if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile2, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got == 0 { + if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } - rc = vfsLock(ctx, mem.mod, pFile2, _EXCLUSIVE_LOCK) + rc = vfsLock(ctx, mod, pFile2, _EXCLUSIVE_LOCK) if rc != _OK { t.Fatal("returned", rc) } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got == 0 { + if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile2, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got == 0 { + if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } - rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK) + rc = vfsLock(ctx, mod, pFile1, _SHARED_LOCK) if rc == _OK { t.Fatal("returned", rc) } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got == 0 { + if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile2, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got == 0 { + if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } - rc = vfsUnlock(ctx, mem.mod, pFile2, _SHARED_LOCK) + rc = vfsUnlock(ctx, mod, pFile2, _SHARED_LOCK) if rc != _OK { t.Fatal("returned", rc) } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile1, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got != 0 { + if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } - rc = vfsCheckReservedLock(ctx, mem.mod, pFile2, pOutput) + rc = vfsCheckReservedLock(ctx, mod, pFile2, pOutput) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(pOutput); got != 0 { + if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } - rc = vfsLock(ctx, mem.mod, pFile1, _SHARED_LOCK) + rc = vfsLock(ctx, mod, pFile1, _SHARED_LOCK) if rc != _OK { t.Fatal("returned", rc) } diff --git a/internal/vfs/vfs_os_bsd.go b/internal/vfs/vfs_os_bsd.go new file mode 100644 index 0000000..2f04d3b --- /dev/null +++ b/internal/vfs/vfs_os_bsd.go @@ -0,0 +1,56 @@ +//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd) + +package vfs + +import ( + "os" + "time" + + "golang.org/x/sys/unix" +) + +func osUnlock(file *os.File, start, len int64) _ErrorCode { + if start == 0 && len == 0 { + err := unix.Flock(int(file.Fd()), unix.LOCK_UN) + if err != nil { + return _IOERR_UNLOCK + } + } + return _OK +} + +func osLock(file *os.File, how int, timeout time.Duration, def _ErrorCode) _ErrorCode { + var err error + for { + err = unix.Flock(int(file.Fd()), how) + if errno, _ := err.(unix.Errno); errno != unix.EAGAIN { + break + } + if timeout < time.Millisecond { + break + } + timeout -= time.Millisecond + time.Sleep(time.Millisecond) + } + return osLockErrorCode(err, def) +} + +func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { + return osLock(file, unix.LOCK_SH|unix.LOCK_NB, timeout, _IOERR_RDLOCK) +} + +func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { + return osLock(file, unix.LOCK_EX|unix.LOCK_NB, timeout, _IOERR_LOCK) +} + +func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) { + lock := unix.Flock_t{ + Type: unix.F_RDLCK, + Start: start, + Len: len, + } + if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil { + return false, _IOERR_CHECKRESERVEDLOCK + } + return lock.Type != unix.F_UNLCK, _OK +} diff --git a/vfs_os_darwin.go b/internal/vfs/vfs_os_darwin.go similarity index 66% rename from vfs_os_darwin.go rename to internal/vfs/vfs_os_darwin.go index 1c3bcc7..4062cd1 100644 --- a/vfs_os_darwin.go +++ b/internal/vfs/vfs_os_darwin.go @@ -1,6 +1,6 @@ //go:build !sqlite3_bsd -package sqlite3 +package vfs import ( "io" @@ -23,14 +23,14 @@ type flocktimeout_t struct { timeout unix.Timespec } -func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error { +func osSync(file *os.File, fullsync, dataonly bool) error { if fullsync { return file.Sync() } return unix.Fsync(int(file.Fd())) } -func (vfsOSMethods) Allocate(file *os.File, size int64) error { +func osAllocate(file *os.File, size int64) error { off, err := file.Seek(0, io.SeekEnd) if err != nil { return err @@ -57,19 +57,19 @@ func (vfsOSMethods) Allocate(file *os.File, size int64) error { return file.Truncate(size) } -func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode { +func osUnlock(file *os.File, start, len int64) _ErrorCode { err := unix.FcntlFlock(file.Fd(), _F_OFD_SETLK, &unix.Flock_t{ Type: unix.F_UNLCK, Start: start, Len: len, }) if err != nil { - return IOERR_UNLOCK + return _IOERR_UNLOCK } return _OK } -func (vfsOSMethods) lock(file *os.File, typ int16, start, len int64, timeout time.Duration, def xErrorCode) xErrorCode { +func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode { lock := flocktimeout_t{fl: unix.Flock_t{ Type: typ, Start: start, @@ -82,25 +82,25 @@ func (vfsOSMethods) lock(file *os.File, typ int16, start, len int64, timeout tim lock.timeout = unix.NsecToTimespec(int64(timeout / time.Nanosecond)) err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLKWTIMEOUT, &lock.fl) } - return vfsOS.lockErrorCode(err, def) + return osLockErrorCode(err, def) } -func (vfsOSMethods) readLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, unix.F_RDLCK, start, len, timeout, IOERR_RDLOCK) +func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { + return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK) } -func (vfsOSMethods) writeLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, unix.F_WRLCK, start, len, timeout, IOERR_LOCK) +func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { + return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK) } -func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) { +func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) { lock := unix.Flock_t{ Type: unix.F_RDLCK, Start: start, Len: len, } if unix.FcntlFlock(file.Fd(), _F_OFD_GETLK, &lock) != nil { - return false, IOERR_CHECKRESERVEDLOCK + return false, _IOERR_CHECKRESERVEDLOCK } return lock.Type != unix.F_UNLCK, _OK } diff --git a/vfs_os_linux.go b/internal/vfs/vfs_os_linux.go similarity index 64% rename from vfs_os_linux.go rename to internal/vfs/vfs_os_linux.go index 7c8f1be..2161380 100644 --- a/vfs_os_linux.go +++ b/internal/vfs/vfs_os_linux.go @@ -1,4 +1,4 @@ -package sqlite3 +package vfs import ( "os" @@ -6,7 +6,7 @@ import ( "golang.org/x/sys/unix" ) -func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error { +func osSync(file *os.File, fullsync, dataonly bool) error { if dataonly { _, _, err := unix.Syscall(unix.SYS_FDATASYNC, file.Fd(), 0, 0) if err != 0 { @@ -17,7 +17,7 @@ func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error { return file.Sync() } -func (vfsOSMethods) Allocate(file *os.File, size int64) error { +func osAllocate(file *os.File, size int64) error { if size == 0 { return nil } diff --git a/vfs_os_ofd.go b/internal/vfs/vfs_os_ofd.go similarity index 51% rename from vfs_os_ofd.go rename to internal/vfs/vfs_os_ofd.go index 9d81d33..27179d4 100644 --- a/vfs_os_ofd.go +++ b/internal/vfs/vfs_os_ofd.go @@ -1,6 +1,6 @@ //go:build linux || illumos -package sqlite3 +package vfs import ( "os" @@ -9,19 +9,19 @@ import ( "golang.org/x/sys/unix" ) -func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode { +func osUnlock(file *os.File, start, len int64) _ErrorCode { err := unix.FcntlFlock(file.Fd(), unix.F_OFD_SETLK, &unix.Flock_t{ Type: unix.F_UNLCK, Start: start, Len: len, }) if err != nil { - return IOERR_UNLOCK + return _IOERR_UNLOCK } return _OK } -func (vfsOSMethods) lock(file *os.File, typ int16, start, len int64, timeout time.Duration, def xErrorCode) xErrorCode { +func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode { lock := unix.Flock_t{ Type: typ, Start: start, @@ -39,25 +39,25 @@ func (vfsOSMethods) lock(file *os.File, typ int16, start, len int64, timeout tim timeout -= time.Millisecond time.Sleep(time.Millisecond) } - return vfsOS.lockErrorCode(err, def) + return osLockErrorCode(err, def) } -func (vfsOSMethods) readLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, unix.F_RDLCK, start, len, timeout, IOERR_RDLOCK) +func osReadLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { + return osLock(file, unix.F_RDLCK, start, len, timeout, _IOERR_RDLOCK) } -func (vfsOSMethods) writeLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, unix.F_WRLCK, start, len, timeout, IOERR_LOCK) +func osWriteLock(file *os.File, start, len int64, timeout time.Duration) _ErrorCode { + return osLock(file, unix.F_WRLCK, start, len, timeout, _IOERR_LOCK) } -func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) { +func osCheckLock(file *os.File, start, len int64) (bool, _ErrorCode) { lock := unix.Flock_t{ Type: unix.F_RDLCK, Start: start, Len: len, } if unix.FcntlFlock(file.Fd(), unix.F_OFD_GETLK, &lock) != nil { - return false, IOERR_CHECKRESERVEDLOCK + return false, _IOERR_CHECKRESERVEDLOCK } return lock.Type != unix.F_UNLCK, _OK } diff --git a/vfs_os_other.go b/internal/vfs/vfs_os_other.go similarity index 60% rename from vfs_os_other.go rename to internal/vfs/vfs_os_other.go index 3134761..23fe3b4 100644 --- a/vfs_os_other.go +++ b/internal/vfs/vfs_os_other.go @@ -1,17 +1,17 @@ //go:build !linux && (!darwin || sqlite3_bsd) -package sqlite3 +package vfs import ( "io" "os" ) -func (vfsOSMethods) Sync(file *os.File, fullsync, dataonly bool) error { +func osSync(file *os.File, fullsync, dataonly bool) error { return file.Sync() } -func (vfsOSMethods) Allocate(file *os.File, size int64) error { +func osAllocate(file *os.File, size int64) error { off, err := file.Seek(0, io.SeekEnd) if err != nil { return err diff --git a/vfs_os_unix.go b/internal/vfs/vfs_os_unix.go similarity index 50% rename from vfs_os_unix.go rename to internal/vfs/vfs_os_unix.go index 2ad7ed4..c609f50 100644 --- a/vfs_os_unix.go +++ b/internal/vfs/vfs_os_unix.go @@ -1,6 +1,6 @@ //go:build unix -package sqlite3 +package vfs import ( "io/fs" @@ -10,11 +10,11 @@ import ( "golang.org/x/sys/unix" ) -func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { +func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { return os.OpenFile(name, flag, perm) } -func (vfsOSMethods) Access(path string, flags _AccessFlag) error { +func osAccess(path string, flags _AccessFlag) error { var access uint32 // unix.F_OK switch flags { case _ACCESS_READWRITE: @@ -25,46 +25,46 @@ func (vfsOSMethods) Access(path string, flags _AccessFlag) error { return unix.Access(path, access) } -func (vfsOSMethods) GetSharedLock(file *os.File, timeout time.Duration) xErrorCode { +func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode { // Test the PENDING lock before acquiring a new SHARED lock. - if pending, _ := vfsOS.checkLock(file, _PENDING_BYTE, 1); pending { - return xErrorCode(BUSY) + if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending { + return _ErrorCode(_BUSY) } // Acquire the SHARED lock. - return vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) + return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) } -func (vfsOSMethods) GetExclusiveLock(file *os.File, timeout time.Duration) xErrorCode { +func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode { if timeout == 0 { timeout = time.Millisecond } // Acquire the EXCLUSIVE lock. - return vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) + return osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) } -func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode { +func osDowngradeLock(file *os.File, state vfsLockState) _ErrorCode { if state >= _EXCLUSIVE_LOCK { // Downgrade to a SHARED lock. - if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK { + if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK { // In theory, the downgrade to a SHARED cannot fail because another // process is holding an incompatible lock. If it does, this // indicates that the other process is not following the locking - // protocol. If this happens, return IOERR_RDLOCK. Returning + // protocol. If this happens, return _IOERR_RDLOCK. Returning // BUSY would confuse the upper layer. - return IOERR_RDLOCK + return _IOERR_RDLOCK } } // Release the PENDING and RESERVED locks. - return vfsOS.unlock(file, _PENDING_BYTE, 2) + return osUnlock(file, _PENDING_BYTE, 2) } -func (vfsOSMethods) ReleaseLock(file *os.File, _ vfsLockState) xErrorCode { +func osReleaseLock(file *os.File, _ vfsLockState) _ErrorCode { // Release all locks. - return vfsOS.unlock(file, 0, 0) + return osUnlock(file, 0, 0) } -func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode { +func osLockErrorCode(err error, def _ErrorCode) _ErrorCode { if err == nil { return _OK } @@ -78,9 +78,9 @@ func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode { unix.ENOLCK, unix.EDEADLK, unix.ETIMEDOUT: - return xErrorCode(BUSY) + return _ErrorCode(_BUSY) case unix.EPERM: - return xErrorCode(PERM) + return _ErrorCode(_PERM) } } return def diff --git a/vfs_os_windows.go b/internal/vfs/vfs_os_windows.go similarity index 67% rename from vfs_os_windows.go rename to internal/vfs/vfs_os_windows.go index 2132e18..e77d36e 100644 --- a/vfs_os_windows.go +++ b/internal/vfs/vfs_os_windows.go @@ -1,4 +1,4 @@ -package sqlite3 +package vfs import ( "io/fs" @@ -9,12 +9,12 @@ import ( "golang.org/x/sys/windows" ) -// OpenFile is a simplified copy of [os.openFileNolog] +// osOpenFile is a simplified copy of [os.openFileNolog] // that uses syscall.FILE_SHARE_DELETE. // https://go.dev/src/os/file_windows.go // // See: https://go.dev/issue/32088 -func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { +func osOpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) { if name == "" { return nil, &os.PathError{Op: "open", Path: name, Err: syscall.ENOENT} } @@ -25,7 +25,7 @@ func (vfsOSMethods) OpenFile(name string, flag int, perm fs.FileMode) (*os.File, return os.NewFile(uintptr(r), name), nil } -func (vfsOSMethods) Access(path string, flags _AccessFlag) error { +func osAccess(path string, flags _AccessFlag) error { fi, err := os.Stat(path) if err != nil { return err @@ -47,88 +47,88 @@ func (vfsOSMethods) Access(path string, flags _AccessFlag) error { return nil } -func (vfsOSMethods) GetSharedLock(file *os.File, timeout time.Duration) xErrorCode { +func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode { // Acquire the PENDING lock temporarily before acquiring a new SHARED lock. - rc := vfsOS.readLock(file, _PENDING_BYTE, 1, timeout) + rc := osReadLock(file, _PENDING_BYTE, 1, timeout) if rc == _OK { // Acquire the SHARED lock. - rc = vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0) + rc = osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0) // Release the PENDING lock. - vfsOS.unlock(file, _PENDING_BYTE, 1) + osUnlock(file, _PENDING_BYTE, 1) } return rc } -func (vfsOSMethods) GetExclusiveLock(file *os.File, timeout time.Duration) xErrorCode { +func osGetExclusiveLock(file *os.File, timeout time.Duration) _ErrorCode { if timeout == 0 { timeout = time.Millisecond } // Release the SHARED lock. - vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE) + osUnlock(file, _SHARED_FIRST, _SHARED_SIZE) // Acquire the EXCLUSIVE lock. - rc := vfsOS.writeLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) + rc := osWriteLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) if rc != _OK { // Reacquire the SHARED lock. - vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0) + osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0) } return rc } -func (vfsOSMethods) DowngradeLock(file *os.File, state vfsLockState) xErrorCode { +func osDowngradeLock(file *os.File, state vfsLockState) _ErrorCode { if state >= _EXCLUSIVE_LOCK { // Release the SHARED lock. - vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE) + osUnlock(file, _SHARED_FIRST, _SHARED_SIZE) // Reacquire the SHARED lock. - if rc := vfsOS.readLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK { + if rc := osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, 0); rc != _OK { // This should never happen. // We should always be able to reacquire the read lock. - return IOERR_RDLOCK + return _IOERR_RDLOCK } } // Release the PENDING and RESERVED locks. if state >= _RESERVED_LOCK { - vfsOS.unlock(file, _RESERVED_BYTE, 1) + osUnlock(file, _RESERVED_BYTE, 1) } if state >= _PENDING_LOCK { - vfsOS.unlock(file, _PENDING_BYTE, 1) + osUnlock(file, _PENDING_BYTE, 1) } return _OK } -func (vfsOSMethods) ReleaseLock(file *os.File, state vfsLockState) xErrorCode { +func osReleaseLock(file *os.File, state vfsLockState) _ErrorCode { // Release all locks. if state >= _RESERVED_LOCK { - vfsOS.unlock(file, _RESERVED_BYTE, 1) + osUnlock(file, _RESERVED_BYTE, 1) } if state >= _SHARED_LOCK { - vfsOS.unlock(file, _SHARED_FIRST, _SHARED_SIZE) + osUnlock(file, _SHARED_FIRST, _SHARED_SIZE) } if state >= _PENDING_LOCK { - vfsOS.unlock(file, _PENDING_BYTE, 1) + osUnlock(file, _PENDING_BYTE, 1) } return _OK } -func (vfsOSMethods) unlock(file *os.File, start, len uint32) xErrorCode { +func osUnlock(file *os.File, start, len uint32) _ErrorCode { err := windows.UnlockFileEx(windows.Handle(file.Fd()), 0, len, 0, &windows.Overlapped{Offset: start}) if err == windows.ERROR_NOT_LOCKED { return _OK } if err != nil { - return IOERR_UNLOCK + return _IOERR_UNLOCK } return _OK } -func (vfsOSMethods) lock(file *os.File, flags, start, len uint32, timeout time.Duration, def xErrorCode) xErrorCode { +func osLock(file *os.File, flags, start, len uint32, timeout time.Duration, def _ErrorCode) _ErrorCode { var err error for { err = windows.LockFileEx(windows.Handle(file.Fd()), flags, @@ -142,35 +142,35 @@ func (vfsOSMethods) lock(file *os.File, flags, start, len uint32, timeout time.D timeout -= time.Millisecond time.Sleep(time.Millisecond) } - return vfsOS.lockErrorCode(err, def) + return osLockErrorCode(err, def) } -func (vfsOSMethods) readLock(file *os.File, start, len uint32, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, +func osReadLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode { + return osLock(file, windows.LOCKFILE_FAIL_IMMEDIATELY, - start, len, timeout, IOERR_RDLOCK) + start, len, timeout, _IOERR_RDLOCK) } -func (vfsOSMethods) writeLock(file *os.File, start, len uint32, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, +func osWriteLock(file *os.File, start, len uint32, timeout time.Duration) _ErrorCode { + return osLock(file, windows.LOCKFILE_FAIL_IMMEDIATELY|windows.LOCKFILE_EXCLUSIVE_LOCK, - start, len, timeout, IOERR_LOCK) + start, len, timeout, _IOERR_LOCK) } -func (vfsOSMethods) checkLock(file *os.File, start, len uint32) (bool, xErrorCode) { - rc := vfsOS.lock(file, +func osCheckLock(file *os.File, start, len uint32) (bool, _ErrorCode) { + rc := osLock(file, windows.LOCKFILE_FAIL_IMMEDIATELY, - start, len, 0, IOERR_CHECKRESERVEDLOCK) - if rc == xErrorCode(BUSY) { + start, len, 0, _IOERR_CHECKRESERVEDLOCK) + if rc == _BUSY { return true, _OK } if rc == _OK { - vfsOS.unlock(file, start, len) + osUnlock(file, start, len) } return false, rc } -func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode { +func osLockErrorCode(err error, def _ErrorCode) _ErrorCode { if err == nil { return _OK } @@ -181,7 +181,7 @@ func (vfsOSMethods) lockErrorCode(err error, def xErrorCode) xErrorCode { windows.ERROR_LOCK_VIOLATION, windows.ERROR_IO_PENDING, windows.ERROR_OPERATION_ABORTED: - return xErrorCode(BUSY) + return _BUSY } } return def diff --git a/vfs_test.go b/internal/vfs/vfs_test.go similarity index 53% rename from vfs_test.go rename to internal/vfs/vfs_test.go index a638e60..1c064fe 100644 --- a/vfs_test.go +++ b/internal/vfs/vfs_test.go @@ -1,4 +1,4 @@ -package sqlite3 +package vfs import ( "bytes" @@ -11,66 +11,67 @@ import ( "testing" "time" + "github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/julianday" ) func Test_vfsLocaltime(t *testing.T) { - mem := newMemory(128) + mod := util.NewMockModule(128) ctx := context.TODO() tm := time.Now() - rc := vfsLocaltime(ctx, mem.mod, 4, tm.Unix()) + rc := vfsLocaltime(ctx, mod, 4, tm.Unix()) if rc != 0 { t.Fatal("returned", rc) } - if s := mem.readUint32(4 + 0*4); int(s) != tm.Second() { + if s := util.ReadUint32(mod, 4+0*4); int(s) != tm.Second() { t.Error("wrong second") } - if m := mem.readUint32(4 + 1*4); int(m) != tm.Minute() { + if m := util.ReadUint32(mod, 4+1*4); int(m) != tm.Minute() { t.Error("wrong minute") } - if h := mem.readUint32(4 + 2*4); int(h) != tm.Hour() { + if h := util.ReadUint32(mod, 4+2*4); int(h) != tm.Hour() { t.Error("wrong hour") } - if d := mem.readUint32(4 + 3*4); int(d) != tm.Day() { + if d := util.ReadUint32(mod, 4+3*4); int(d) != tm.Day() { t.Error("wrong day") } - if m := mem.readUint32(4 + 4*4); time.Month(1+m) != tm.Month() { + if m := util.ReadUint32(mod, 4+4*4); time.Month(1+m) != tm.Month() { t.Error("wrong month") } - if y := mem.readUint32(4 + 5*4); 1900+int(y) != tm.Year() { + if y := util.ReadUint32(mod, 4+5*4); 1900+int(y) != tm.Year() { t.Error("wrong year") } - if w := mem.readUint32(4 + 6*4); time.Weekday(w) != tm.Weekday() { + if w := util.ReadUint32(mod, 4+6*4); time.Weekday(w) != tm.Weekday() { t.Error("wrong weekday") } - if d := mem.readUint32(4 + 7*4); int(d) != tm.YearDay()-1 { + if d := util.ReadUint32(mod, 4+7*4); int(d) != tm.YearDay()-1 { t.Error("wrong yearday") } } func Test_vfsRandomness(t *testing.T) { - mem := newMemory(128) + mod := util.NewMockModule(128) ctx := context.TODO() - rc := vfsRandomness(ctx, mem.mod, 0, 16, 4) + rc := vfsRandomness(ctx, mod, 0, 16, 4) if rc != 16 { t.Fatal("returned", rc) } var zero [16]byte - if got := mem.view(4, 16); bytes.Equal(got, zero[:]) { + if got := util.View(mod, 4, 16); bytes.Equal(got, zero[:]) { t.Fatal("all zero") } } func Test_vfsSleep(t *testing.T) { - mem := newMemory(128) + mod := util.NewMockModule(128) ctx := context.TODO() now := time.Now() - rc := vfsSleep(ctx, mem.mod, 0, 123456) + rc := vfsSleep(ctx, mod, 0, 123456) if rc != 0 { t.Fatal("returned", rc) } @@ -82,56 +83,56 @@ func Test_vfsSleep(t *testing.T) { } func Test_vfsCurrentTime(t *testing.T) { - mem := newMemory(128) + mod := util.NewMockModule(128) ctx := context.TODO() now := time.Now() - rc := vfsCurrentTime(ctx, mem.mod, 0, 4) + rc := vfsCurrentTime(ctx, mod, 0, 4) if rc != 0 { t.Fatal("returned", rc) } want := julianday.Float(now) - if got := mem.readFloat64(4); float32(got) != float32(want) { + if got := util.ReadFloat64(mod, 4); float32(got) != float32(want) { t.Errorf("got %v, want %v", got, want) } } func Test_vfsCurrentTime64(t *testing.T) { - mem := newMemory(128) + mod := util.NewMockModule(128) ctx := context.TODO() now := time.Now() time.Sleep(time.Millisecond) - rc := vfsCurrentTime64(ctx, mem.mod, 0, 4) + rc := vfsCurrentTime64(ctx, mod, 0, 4) if rc != 0 { t.Fatal("returned", rc) } day, nsec := julianday.Date(now) want := day*86_400_000 + nsec/1_000_000 - if got := mem.readUint64(4); float32(got) != float32(want) { + if got := util.ReadUint64(mod, 4); float32(got) != float32(want) { t.Errorf("got %v, want %v", got, want) } } func Test_vfsFullPathname(t *testing.T) { - mem := newMemory(128 + _MAX_PATHNAME) - mem.writeString(4, ".") + mod := util.NewMockModule(128 + _MAX_PATHNAME) + util.WriteString(mod, 4, ".") ctx := context.TODO() - rc := vfsFullPathname(ctx, mem.mod, 0, 4, 0, 8) - if rc != uint32(CANTOPEN_FULLPATH) { - t.Errorf("returned %d, want %d", rc, CANTOPEN_FULLPATH) + rc := vfsFullPathname(ctx, mod, 0, 4, 0, 8) + if rc != _CANTOPEN_FULLPATH { + t.Errorf("returned %d, want %d", rc, _CANTOPEN_FULLPATH) } - rc = vfsFullPathname(ctx, mem.mod, 0, 4, _MAX_PATHNAME, 8) + rc = vfsFullPathname(ctx, mod, 0, 4, _MAX_PATHNAME, 8) if rc != _OK { t.Fatal("returned", rc) } want, _ := filepath.Abs(".") - if got := mem.readString(8, _MAX_PATHNAME); got != want { + if got := util.ReadString(mod, 8, _MAX_PATHNAME); got != want { t.Errorf("got %v, want %v", got, want) } } @@ -145,11 +146,11 @@ func Test_vfsDelete(t *testing.T) { } file.Close() - mem := newMemory(128 + _MAX_PATHNAME) - mem.writeString(4, name) + mod := util.NewMockModule(128 + _MAX_PATHNAME) + util.WriteString(mod, 4, name) ctx := context.TODO() - rc := vfsDelete(ctx, mem.mod, 0, 4, 1) + rc := vfsDelete(ctx, mod, 0, 4, 1) if rc != _OK { t.Fatal("returned", rc) } @@ -158,8 +159,8 @@ func Test_vfsDelete(t *testing.T) { t.Fatal("did not delete the file") } - rc = vfsDelete(ctx, mem.mod, 0, 4, 1) - if rc != uint32(IOERR_DELETE_NOENT) { + rc = vfsDelete(ctx, mod, 0, 4, 1) + if rc != _IOERR_DELETE_NOENT { t.Fatal("returned", rc) } } @@ -176,99 +177,99 @@ func Test_vfsAccess(t *testing.T) { t.Fatal(err) } - mem := newMemory(128 + _MAX_PATHNAME) - mem.writeString(8, dir) + mod := util.NewMockModule(128 + _MAX_PATHNAME) + util.WriteString(mod, 8, dir) ctx := context.TODO() - rc := vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_EXISTS, 4) + rc := vfsAccess(ctx, mod, 0, 8, _ACCESS_EXISTS, 4) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(4); got != 1 { + if got := util.ReadUint32(mod, 4); got != 1 { t.Error("directory did not exist") } - rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4) + rc = vfsAccess(ctx, mod, 0, 8, _ACCESS_READWRITE, 4) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(4); got != 1 { + if got := util.ReadUint32(mod, 4); got != 1 { t.Error("can't access directory") } - mem.writeString(8, file) - rc = vfsAccess(ctx, mem.mod, 0, 8, _ACCESS_READWRITE, 4) + util.WriteString(mod, 8, file) + rc = vfsAccess(ctx, mod, 0, 8, _ACCESS_READWRITE, 4) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(4); got != 0 { + if got := util.ReadUint32(mod, 4); got != 0 { t.Error("can access file") } } func Test_vfsFile(t *testing.T) { - mem := newMemory(128) - ctx, vfs := vfsContext(context.TODO()) + mod := util.NewMockModule(128) + ctx, vfs := Context(context.TODO()) defer vfs.Close() // Open a temporary file. - rc := vfsOpen(ctx, mem.mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0) + rc := vfsOpen(ctx, mod, 0, 0, 4, _OPEN_CREATE|_OPEN_EXCLUSIVE|_OPEN_READWRITE|_OPEN_DELETEONCLOSE, 0) if rc != _OK { t.Fatal("returned", rc) } // Write stuff. text := "Hello world!" - mem.writeString(16, text) - rc = vfsWrite(ctx, mem.mod, 4, 16, uint32(len(text)), 0) + util.WriteString(mod, 16, text) + rc = vfsWrite(ctx, mod, 4, 16, uint32(len(text)), 0) if rc != _OK { t.Fatal("returned", rc) } // Check file size. - rc = vfsFileSize(ctx, mem.mod, 4, 16) + rc = vfsFileSize(ctx, mod, 4, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(16); got != uint32(len(text)) { + if got := util.ReadUint32(mod, 16); got != uint32(len(text)) { t.Errorf("got %d", got) } // Partial read at offset. - rc = vfsRead(ctx, mem.mod, 4, 16, uint32(len(text)), 4) - if rc != uint32(IOERR_SHORT_READ) { + rc = vfsRead(ctx, mod, 4, 16, uint32(len(text)), 4) + if rc != _IOERR_SHORT_READ { t.Fatal("returned", rc) } - if got := mem.readString(16, 64); got != text[4:] { + if got := util.ReadString(mod, 16, 64); got != text[4:] { t.Errorf("got %q", got) } // Truncate the file. - rc = vfsTruncate(ctx, mem.mod, 4, 4) + rc = vfsTruncate(ctx, mod, 4, 4) if rc != _OK { t.Fatal("returned", rc) } // Check file size. - rc = vfsFileSize(ctx, mem.mod, 4, 16) + rc = vfsFileSize(ctx, mod, 4, 16) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readUint32(16); got != 4 { + if got := util.ReadUint32(mod, 16); got != 4 { t.Errorf("got %d", got) } // Read at offset. - rc = vfsRead(ctx, mem.mod, 4, 32, 4, 0) + rc = vfsRead(ctx, mod, 4, 32, 4, 0) if rc != _OK { t.Fatal("returned", rc) } - if got := mem.readString(32, 64); got != text[:4] { + if got := util.ReadString(mod, 32, 64); got != text[:4] { t.Errorf("got %q", got) } // Close the file. - rc = vfsClose(ctx, mem.mod, 4) + rc = vfsClose(ctx, mod, 4) if rc != _OK { t.Fatal("returned", rc) } diff --git a/mem.go b/mem.go deleted file mode 100644 index 7fb6fa2..0000000 --- a/mem.go +++ /dev/null @@ -1,151 +0,0 @@ -package sqlite3 - -import ( - "bytes" - "math" - - "github.com/tetratelabs/wazero/api" -) - -type memory struct { - mod api.Module -} - -func (m memory) view(ptr uint32, size uint64) []byte { - if ptr == 0 { - panic(nilErr) - } - if size > math.MaxUint32 { - panic(rangeErr) - } - buf, ok := m.mod.Memory().Read(ptr, uint32(size)) - if !ok { - panic(rangeErr) - } - return buf -} - -func (m memory) readUint8(ptr uint32) uint8 { - if ptr == 0 { - panic(nilErr) - } - v, ok := m.mod.Memory().ReadByte(ptr) - if !ok { - panic(rangeErr) - } - return v -} - -func (m memory) writeUint8(ptr uint32, v uint8) { - if ptr == 0 { - panic(nilErr) - } - ok := m.mod.Memory().WriteByte(ptr, v) - if !ok { - panic(rangeErr) - } -} - -func (m memory) readUint32(ptr uint32) uint32 { - if ptr == 0 { - panic(nilErr) - } - v, ok := m.mod.Memory().ReadUint32Le(ptr) - if !ok { - panic(rangeErr) - } - return v -} - -func (m memory) writeUint32(ptr uint32, v uint32) { - if ptr == 0 { - panic(nilErr) - } - ok := m.mod.Memory().WriteUint32Le(ptr, v) - if !ok { - panic(rangeErr) - } -} - -func (m memory) readUint64(ptr uint32) uint64 { - if ptr == 0 { - panic(nilErr) - } - v, ok := m.mod.Memory().ReadUint64Le(ptr) - if !ok { - panic(rangeErr) - } - return v -} - -func (m memory) writeUint64(ptr uint32, v uint64) { - if ptr == 0 { - panic(nilErr) - } - ok := m.mod.Memory().WriteUint64Le(ptr, v) - if !ok { - panic(rangeErr) - } -} - -func (m memory) readBool8(ptr uint32) bool { - v := m.readUint8(ptr) - if v != 0 { - return true - } - return false -} - -func (m memory) writeBool8(ptr uint32, v bool) { - var b uint8 - if v { - b = 1 - } - m.writeUint8(ptr, b) -} - -func (m memory) readFloat64(ptr uint32) float64 { - return math.Float64frombits(m.readUint64(ptr)) -} - -func (m memory) writeFloat64(ptr uint32, v float64) { - m.writeUint64(ptr, math.Float64bits(v)) -} - -func (m memory) readString(ptr, maxlen uint32) string { - if ptr == 0 { - panic(nilErr) - } - switch maxlen { - case 0: - return "" - case math.MaxUint32: - // avoid overflow - default: - maxlen = maxlen + 1 - } - mem := m.mod.Memory() - buf, ok := mem.Read(ptr, maxlen) - if !ok { - buf, ok = mem.Read(ptr, mem.Size()-ptr) - if !ok { - panic(rangeErr) - } - } - if i := bytes.IndexByte(buf, 0); i < 0 { - panic(noNulErr) - } else { - return string(buf[:i]) - } -} - -func (m memory) writeBytes(ptr uint32, b []byte) { - buf := m.view(ptr, uint64(len(b))) - copy(buf, b) -} - -func (m memory) writeString(ptr uint32, s string) { - buf := m.view(ptr, uint64(len(s)+1)) - buf[len(s)] = 0 - copy(buf, s) -} diff --git a/mem_test.go b/mem_test.go deleted file mode 100644 index 54322bf..0000000 --- a/mem_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package sqlite3 - -import ( - "math" - "testing" -) - -func Test_memory_view_nil(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.view(0, 8) - t.Error("want panic") -} - -func Test_memory_view_range(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.view(126, 8) - t.Error("want panic") -} - -func Test_memory_view_overflow(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.view(1, math.MaxInt64) - t.Error("want panic") -} - -func Test_memory_readUint32_nil(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.readUint32(0) - t.Error("want panic") -} - -func Test_memory_readUint32_range(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.readUint32(126) - t.Error("want panic") -} - -func Test_memory_readUint64_nil(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.readUint64(0) - t.Error("want panic") -} - -func Test_memory_readUint64_range(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.readUint64(126) - t.Error("want panic") -} - -func Test_memory_writeUint32_nil(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.writeUint32(0, 1) - t.Error("want panic") -} - -func Test_memory_writeUint32_range(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.writeUint32(126, 1) - t.Error("want panic") -} - -func Test_memory_writeUint64_nil(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.writeUint64(0, 1) - t.Error("want panic") -} - -func Test_memory_writeUint64_range(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.writeUint64(126, 1) - t.Error("want panic") -} - -func Test_memory_readString_range(t *testing.T) { - defer func() { _ = recover() }() - mem := newMemory(128) - mem.readString(130, math.MaxUint32) - t.Error("want panic") -} diff --git a/module.go b/module.go index a3b7a1f..362a407 100644 --- a/module.go +++ b/module.go @@ -8,6 +8,8 @@ import ( "os" "sync" + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/go-sqlite3/internal/vfs" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" ) @@ -50,7 +52,7 @@ func instantiateModule() (*module, error) { func compileModule() { ctx := context.Background() sqlite3.runtime = wazero.NewRuntime(ctx) - vfsInstantiate(ctx, sqlite3.runtime) + vfs.Instantiate(ctx, sqlite3.runtime) bin := Binary if bin == nil && Path != "" { @@ -60,7 +62,7 @@ func compileModule() { } } if bin == nil { - sqlite3.err = binaryErr + sqlite3.err = util.BinaryErr return } @@ -69,20 +71,20 @@ func compileModule() { type module struct { ctx context.Context - mem memory + mod api.Module api sqliteAPI vfs io.Closer } func newModule(mod api.Module) (m *module, err error) { m = &module{} - m.mem = memory{mod} - m.ctx, m.vfs = vfsContext(context.Background()) + m.mod = mod + m.ctx, m.vfs = vfs.Context(context.Background()) getFun := func(name string) api.Function { f := mod.ExportedFunction(name) if f == nil { - err = noFuncErr + errorString(name) + err = util.NoFuncErr + util.ErrorString(name) return nil } return f @@ -91,10 +93,10 @@ func newModule(mod api.Module) (m *module, err error) { getVal := func(name string) uint32 { global := mod.ExportedGlobal(name) if global == nil { - err = noGlobalErr + errorString(name) + err = util.NoGlobalErr + util.ErrorString(name) return 0 } - return m.mem.readUint32(uint32(global.Get())) + return util.ReadUint32(mod, uint32(global.Get())) } m.api = sqliteAPI{ @@ -154,7 +156,7 @@ func newModule(mod api.Module) (m *module, err error) { } func (m *module) close() error { - err := m.mem.mod.Close(m.ctx) + err := m.mod.Close(m.ctx) m.vfs.Close() return err } @@ -167,19 +169,19 @@ func (m *module) error(rc uint64, handle uint32, sql ...string) error { err := Error{code: rc} if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM { - panic(oomErr) + panic(util.OOMErr) } var r []uint64 r = m.call(m.api.errstr, rc) if r != nil { - err.str = m.mem.readString(uint32(r[0]), _MAX_STRING) + err.str = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING) } r = m.call(m.api.errmsg, uint64(handle)) if r != nil { - err.msg = m.mem.readString(uint32(r[0]), _MAX_STRING) + err.msg = util.ReadString(m.mod, uint32(r[0]), _MAX_STRING) } if sql != nil { @@ -215,12 +217,12 @@ func (m *module) free(ptr uint32) { func (m *module) new(size uint64) uint32 { if size > _MAX_ALLOCATION_SIZE { - panic(oomErr) + panic(util.OOMErr) } r := m.call(m.api.malloc, size) ptr := uint32(r[0]) if ptr == 0 && size != 0 { - panic(oomErr) + panic(util.OOMErr) } return ptr } @@ -230,13 +232,13 @@ func (m *module) newBytes(b []byte) uint32 { return 0 } ptr := m.new(uint64(len(b))) - m.mem.writeBytes(ptr, b) + util.WriteBytes(m.mod, ptr, b) return ptr } func (m *module) newString(s string) uint32 { ptr := m.new(uint64(len(s) + 1)) - m.mem.writeString(ptr, s) + util.WriteString(m.mod, ptr, s) return ptr } @@ -286,7 +288,7 @@ func (a *arena) new(size uint64) uint32 { func (a *arena) string(s string) uint32 { ptr := a.new(uint64(len(s) + 1)) - a.m.mem.writeString(ptr, s) + util.WriteString(a.m.mod, ptr, s) return ptr } diff --git a/module_test.go b/module_test.go index a0bcebd..de64b12 100644 --- a/module_test.go +++ b/module_test.go @@ -4,8 +4,14 @@ import ( "bytes" "math" "testing" + + "github.com/ncruces/go-sqlite3/internal/util" ) +func init() { + Path = "./embed/sqlite3.wasm" +} + func TestConn_error_OOM(t *testing.T) { t.Parallel() @@ -71,7 +77,7 @@ func TestConn_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := m.mem.readString(ptr, math.MaxUint32); got != title { + if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != title { t.Errorf("got %q, want %q", got, title) } @@ -80,7 +86,7 @@ func TestConn_newArena(t *testing.T) { if ptr == 0 { t.Fatalf("got nullptr") } - if got := m.mem.readString(ptr, math.MaxUint32); got != body { + if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body { t.Errorf("got %q, want %q", got, body) } arena.free() @@ -107,7 +113,7 @@ func TestConn_newBytes(t *testing.T) { } want := buf - if got := m.mem.view(ptr, uint64(len(want))); !bytes.Equal(got, want) { + if got := util.View(m.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } } @@ -133,7 +139,7 @@ func TestConn_newString(t *testing.T) { } want := str + "\000" - if got := m.mem.view(ptr, uint64(len(want))); string(got) != want { + if got := util.View(m.mod, ptr, uint64(len(want))); string(got) != want { t.Errorf("got %q, want %q", got, want) } } @@ -159,22 +165,22 @@ func TestConn_getString(t *testing.T) { } want := "sqlite3" - if got := m.mem.readString(ptr, math.MaxUint32); got != want { + if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != want { t.Errorf("got %q, want %q", got, want) } - if got := m.mem.readString(ptr, 0); got != "" { + if got := util.ReadString(m.mod, ptr, 0); got != "" { t.Errorf("got %q, want empty", got) } func() { defer func() { _ = recover() }() - m.mem.readString(ptr, uint32(len(want)/2)) + util.ReadString(m.mod, ptr, uint32(len(want)/2)) t.Error("want panic") }() func() { defer func() { _ = recover() }() - m.mem.readString(0, math.MaxUint32) + util.ReadString(m.mod, 0, math.MaxUint32) t.Error("want panic") }() } diff --git a/stmt.go b/stmt.go index 067331d..ef07619 100644 --- a/stmt.go +++ b/stmt.go @@ -3,6 +3,8 @@ package sqlite3 import ( "math" "time" + + "github.com/ncruces/go-sqlite3/internal/util" ) // Stmt is a prepared statement object. @@ -119,7 +121,7 @@ func (s *Stmt) BindName(param int) string { if ptr == 0 { return "" } - return s.c.mem.readString(ptr, _MAX_STRING) + return util.ReadString(s.c.mod, ptr, _MAX_STRING) } // BindBool binds a bool to the prepared statement. @@ -223,7 +225,7 @@ func (s *Stmt) BindTime(param int, value time.Time, format TimeFormat) error { case float64: s.BindFloat(param, v) default: - panic(assertErr()) + panic(util.AssertErr()) } return nil } @@ -247,9 +249,9 @@ func (s *Stmt) ColumnName(col int) string { ptr := uint32(r[0]) if ptr == 0 { - panic(oomErr) + panic(util.OOMErr) } - return s.c.mem.readString(ptr, _MAX_STRING) + return util.ReadString(s.c.mod, ptr, _MAX_STRING) } // ColumnType returns the initial [Datatype] of the result column. @@ -320,7 +322,7 @@ func (s *Stmt) ColumnTime(col int, format TimeFormat) time.Time { case NULL: return time.Time{} default: - panic(assertErr()) + panic(util.AssertErr()) } t, err := format.Decode(v) if err != nil { @@ -366,7 +368,7 @@ func (s *Stmt) ColumnRawText(col int) []byte { r = s.c.call(s.c.api.columnBytes, uint64(s.handle), uint64(col)) - return s.c.mem.view(ptr, r[0]) + return util.View(s.c.mod, ptr, r[0]) } // ColumnRawBlob returns the value of the result column as a []byte. @@ -389,7 +391,7 @@ func (s *Stmt) ColumnRawBlob(col int) []byte { r = s.c.call(s.c.api.columnBytes, uint64(s.handle), uint64(col)) - return s.c.mem.view(ptr, r[0]) + return util.View(s.c.mod, ptr, r[0]) } // Return true if stmt is an empty SQL statement. diff --git a/tests/mptest/mptest_test.go b/tests/mptest/mptest_test.go index 23b18de..2d42d28 100644 --- a/tests/mptest/mptest_test.go +++ b/tests/mptest/mptest_test.go @@ -23,6 +23,7 @@ import ( "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" _ "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/vfs" ) //go:embed testdata/mptest.wasm @@ -31,12 +32,6 @@ var binary []byte //go:embed testdata/*.*test var scripts embed.FS -//go:linkname vfsNewEnvModuleBuilder github.com/ncruces/go-sqlite3.vfsNewEnvModuleBuilder -func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder - -//go:linkname vfsContext github.com/ncruces/go-sqlite3.vfsContext -func vfsContext(ctx context.Context) (context.Context, io.Closer) - var ( rt wazero.Runtime module wazero.CompiledModule @@ -48,7 +43,7 @@ func init() { rt = wazero.NewRuntime(ctx) wasi_snapshot_preview1.MustInstantiate(ctx, rt) - env := vfsNewEnvModuleBuilder(rt) + env := vfs.NewEnvModuleBuilder(rt) env.NewFunctionBuilder().WithFunc(system).Export("system") _, err := env.Instantiate(ctx) if err != nil { @@ -88,7 +83,7 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 { cfg := config(ctx).WithArgs(args...) go func() { - ctx, vfs := vfsContext(ctx) + ctx, vfs := vfs.Context(ctx) rt.InstantiateModule(ctx, module, cfg) vfs.Close() }() @@ -96,7 +91,7 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 { } func Test_config01(t *testing.T) { - ctx, vfs := vfsContext(newContext(t)) + ctx, vfs := vfs.Context(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "config01.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -115,7 +110,7 @@ func Test_config02(t *testing.T) { t.Skip("skipping in CI") } - ctx, vfs := vfsContext(newContext(t)) + ctx, vfs := vfs.Context(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "config02.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -131,7 +126,7 @@ func Test_crash01(t *testing.T) { t.Skip("skipping in short mode") } - ctx, vfs := vfsContext(newContext(t)) + ctx, vfs := vfs.Context(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "crash01.test") mod, err := rt.InstantiateModule(ctx, module, cfg) @@ -147,7 +142,7 @@ func Test_multiwrite01(t *testing.T) { t.Skip("skipping in short mode") } - ctx, vfs := vfsContext(newContext(t)) + ctx, vfs := vfs.Context(newContext(t)) name := filepath.Join(t.TempDir(), "test.db") cfg := config(ctx).WithArgs("mptest", name, "multiwrite01.test") mod, err := rt.InstantiateModule(ctx, module, cfg) diff --git a/tests/speedtest1/speedtest1_test.go b/tests/speedtest1/speedtest1_test.go index 9a3f913..47c3bab 100644 --- a/tests/speedtest1/speedtest1_test.go +++ b/tests/speedtest1/speedtest1_test.go @@ -19,17 +19,12 @@ import ( "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" _ "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/vfs" ) //go:embed testdata/speedtest1.wasm var binary []byte -//go:linkname vfsNewEnvModuleBuilder github.com/ncruces/go-sqlite3.vfsNewEnvModuleBuilder -func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder - -//go:linkname vfsContext github.com/ncruces/go-sqlite3.vfsContext -func vfsContext(ctx context.Context) (context.Context, io.Closer) - var ( rt wazero.Runtime module wazero.CompiledModule @@ -42,7 +37,7 @@ func init() { rt = wazero.NewRuntime(ctx) wasi_snapshot_preview1.MustInstantiate(ctx, rt) - env := vfsNewEnvModuleBuilder(rt) + env := vfs.NewEnvModuleBuilder(rt) _, err := env.Instantiate(ctx) if err != nil { panic(err) @@ -74,7 +69,7 @@ func TestMain(m *testing.M) { func Benchmark_speedtest1(b *testing.B) { output.Reset() - ctx, vfs := vfsContext(context.Background()) + ctx, vfs := vfs.Context(context.Background()) name := filepath.Join(b.TempDir(), "test.db") args := append(options, "--size", strconv.Itoa(b.N), name) cfg := wazero.NewModuleConfig(). diff --git a/time.go b/time.go index 00a9ed9..6fddc3e 100644 --- a/time.go +++ b/time.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/julianday" ) @@ -148,7 +149,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { case int64: return julianday.Time(v, 0), nil default: - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } case TimeFormatUnix, TimeFormatUnixFrac: @@ -167,7 +168,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { case int64: return time.Unix(v, 0), nil default: - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } case TimeFormatUnixMilli: @@ -184,7 +185,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { case int64: return time.UnixMilli(int64(v)), nil default: - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } case TimeFormatUnixMicro: @@ -201,14 +202,14 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { case int64: return time.UnixMicro(int64(v)), nil default: - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } case TimeFormatUnixNano: if s, ok := v.(string); ok { i, err := strconv.ParseInt(s, 10, 64) if err != nil { - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } v = i } @@ -218,7 +219,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { case int64: return time.Unix(0, int64(v)), nil default: - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } // Special formats @@ -288,7 +289,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { } return TimeFormatUnixNano.Decode(v) default: - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } case @@ -300,7 +301,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { TimeFormat7, TimeFormat7TZ: s, ok := v.(string) if !ok { - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } return f.parseRelaxed(s) @@ -310,7 +311,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { TimeFormat10, TimeFormat10TZ: s, ok := v.(string) if !ok { - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } t, err := f.parseRelaxed(s) return t.AddDate(2000, 0, 0), err @@ -318,7 +319,7 @@ func (f TimeFormat) Decode(v any) (time.Time, error) { default: s, ok := v.(string) if !ok { - return time.Time{}, timeErr + return time.Time{}, util.TimeErr } if f == "" { f = time.RFC3339Nano diff --git a/vfs.go b/vfs.go deleted file mode 100644 index de0ff6c..0000000 --- a/vfs.go +++ /dev/null @@ -1,440 +0,0 @@ -package sqlite3 - -import ( - "context" - "crypto/rand" - "errors" - "io" - "io/fs" - "os" - "path/filepath" - "runtime" - "time" - - "github.com/ncruces/julianday" - "github.com/tetratelabs/wazero" - "github.com/tetratelabs/wazero/api" -) - -func vfsInstantiate(ctx context.Context, r wazero.Runtime) { - env := vfsNewEnvModuleBuilder(r) - _, err := env.Instantiate(ctx) - if err != nil { - panic(err) - } -} - -func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder { - env := r.NewHostModuleBuilder("env") - vfsRegisterFuncT(env, "os_localtime", vfsLocaltime) - vfsRegisterFunc3(env, "os_randomness", vfsRandomness) - vfsRegisterFunc2(env, "os_sleep", vfsSleep) - vfsRegisterFunc2(env, "os_current_time", vfsCurrentTime) - vfsRegisterFunc2(env, "os_current_time_64", vfsCurrentTime64) - vfsRegisterFunc4(env, "os_full_pathname", vfsFullPathname) - vfsRegisterFunc3(env, "os_delete", vfsDelete) - vfsRegisterFunc4(env, "os_access", vfsAccess) - vfsRegisterFunc5(env, "os_open", vfsOpen) - vfsRegisterFunc1(env, "os_close", vfsClose) - vfsRegisterFuncRW(env, "os_read", vfsRead) - vfsRegisterFuncRW(env, "os_write", vfsWrite) - vfsRegisterFuncT(env, "os_truncate", vfsTruncate) - vfsRegisterFunc2(env, "os_sync", vfsSync) - vfsRegisterFunc2(env, "os_file_size", vfsFileSize) - vfsRegisterFunc2(env, "os_lock", vfsLock) - vfsRegisterFunc2(env, "os_unlock", vfsUnlock) - vfsRegisterFunc2(env, "os_check_reserved_lock", vfsCheckReservedLock) - vfsRegisterFunc3(env, "os_file_control", vfsFileControl) - return env -} - -// Poor man's namespaces. -const ( - vfsOS vfsOSMethods = false - vfsFile vfsFileMethods = false -) - -type ( - vfsOSMethods bool - vfsFileMethods bool -) - -type vfsKey struct{} -type vfsState struct { - files []*os.File -} - -func vfsContext(ctx context.Context) (context.Context, io.Closer) { - vfs := &vfsState{} - return context.WithValue(ctx, vfsKey{}, vfs), vfs -} - -func (vfs *vfsState) Close() error { - for _, f := range vfs.files { - if f != nil { - f.Close() - } - } - vfs.files = nil - return nil -} - -func vfsLocaltime(ctx context.Context, mod api.Module, pTm uint32, t int64) uint32 { - tm := time.Unix(t, 0) - var isdst int - if tm.IsDST() { - isdst = 1 - } - - // https://pubs.opengroup.org/onlinepubs/7908799/xsh/time.h.html - mem := memory{mod} - mem.writeUint32(pTm+0*ptrlen, uint32(tm.Second())) - mem.writeUint32(pTm+1*ptrlen, uint32(tm.Minute())) - mem.writeUint32(pTm+2*ptrlen, uint32(tm.Hour())) - mem.writeUint32(pTm+3*ptrlen, uint32(tm.Day())) - mem.writeUint32(pTm+4*ptrlen, uint32(tm.Month()-time.January)) - mem.writeUint32(pTm+5*ptrlen, uint32(tm.Year()-1900)) - mem.writeUint32(pTm+6*ptrlen, uint32(tm.Weekday()-time.Sunday)) - mem.writeUint32(pTm+7*ptrlen, uint32(tm.YearDay()-1)) - mem.writeUint32(pTm+8*ptrlen, uint32(isdst)) - return _OK -} - -func vfsRandomness(ctx context.Context, mod api.Module, pVfs, nByte, zByte uint32) uint32 { - mem := memory{mod}.view(zByte, uint64(nByte)) - n, _ := rand.Reader.Read(mem) - return uint32(n) -} - -func vfsSleep(ctx context.Context, mod api.Module, pVfs, nMicro uint32) uint32 { - time.Sleep(time.Duration(nMicro) * time.Microsecond) - return _OK -} - -func vfsCurrentTime(ctx context.Context, mod api.Module, pVfs, prNow uint32) uint32 { - day := julianday.Float(time.Now()) - memory{mod}.writeFloat64(prNow, day) - return _OK -} - -func vfsCurrentTime64(ctx context.Context, mod api.Module, pVfs, piNow uint32) uint32 { - day, nsec := julianday.Date(time.Now()) - msec := day*86_400_000 + nsec/1_000_000 - memory{mod}.writeUint64(piNow, uint64(msec)) - return _OK -} - -func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull, zFull uint32) uint32 { - rel := memory{mod}.readString(zRelative, _MAX_PATHNAME) - abs, err := filepath.Abs(rel) - if err != nil { - return uint32(CANTOPEN_FULLPATH) - } - - size := uint64(len(abs) + 1) - if size > uint64(nFull) { - return uint32(CANTOPEN_FULLPATH) - } - mem := memory{mod}.view(zFull, size) - mem[len(abs)] = 0 - copy(mem, abs) - - if fi, err := os.Lstat(abs); err == nil { - if fi.Mode()&fs.ModeSymlink != 0 { - return _OK_SYMLINK - } - return _OK - } else if errors.Is(err, fs.ErrNotExist) { - return _OK - } - return uint32(CANTOPEN_FULLPATH) -} - -func vfsDelete(ctx context.Context, mod api.Module, pVfs, zPath, syncDir uint32) uint32 { - path := memory{mod}.readString(zPath, _MAX_PATHNAME) - err := os.Remove(path) - if errors.Is(err, fs.ErrNotExist) { - return uint32(IOERR_DELETE_NOENT) - } - if err != nil { - return uint32(IOERR_DELETE) - } - if runtime.GOOS != "windows" && syncDir != 0 { - f, err := os.Open(filepath.Dir(path)) - if err != nil { - return _OK - } - defer f.Close() - err = vfsOS.Sync(f, false, false) - if err != nil { - return uint32(IOERR_DIR_FSYNC) - } - } - return _OK -} - -func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath uint32, flags _AccessFlag, pResOut uint32) uint32 { - path := memory{mod}.readString(zPath, _MAX_PATHNAME) - err := vfsOS.Access(path, flags) - - var res uint32 - var rc xErrorCode - if flags == _ACCESS_EXISTS { - switch { - case err == nil: - res = 1 - case errors.Is(err, fs.ErrNotExist): - res = 0 - default: - rc = IOERR_ACCESS - } - } else { - switch { - case err == nil: - res = 1 - case errors.Is(err, fs.ErrPermission): - res = 0 - default: - rc = IOERR_ACCESS - } - } - - memory{mod}.writeUint32(pResOut, res) - return uint32(rc) -} - -func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, flags OpenFlag, pOutFlags uint32) uint32 { - var oflags int - if flags&OPEN_EXCLUSIVE != 0 { - oflags |= os.O_EXCL - } - if flags&OPEN_CREATE != 0 { - oflags |= os.O_CREATE - } - if flags&OPEN_READONLY != 0 { - oflags |= os.O_RDONLY - } - if flags&OPEN_READWRITE != 0 { - oflags |= os.O_RDWR - } - - var err error - var file *os.File - if zName == 0 { - file, err = os.CreateTemp("", "*.db") - } else { - name := memory{mod}.readString(zName, _MAX_PATHNAME) - file, err = vfsOS.OpenFile(name, oflags, 0666) - } - if err != nil { - return uint32(CANTOPEN) - } - - if flags&OPEN_DELETEONCLOSE != 0 { - os.Remove(file.Name()) - } - - vfsFile.Open(ctx, mod, pFile, file) - - if flags&OPEN_READONLY != 0 { - vfsFile.SetReadOnly(ctx, mod, pFile, true) - } - if runtime.GOOS != "windows" && - flags&(OPEN_CREATE) != 0 && - flags&(OPEN_MAIN_JOURNAL|OPEN_SUPER_JOURNAL|OPEN_WAL) != 0 { - vfsFile.SetSyncDir(ctx, mod, pFile, true) - } - - if pOutFlags != 0 { - memory{mod}.writeUint32(pOutFlags, uint32(flags)) - } - return _OK -} - -func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 { - err := vfsFile.Close(ctx, mod, pFile) - if err != nil { - return uint32(IOERR_CLOSE) - } - return _OK -} - -func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) uint32 { - buf := memory{mod}.view(zBuf, uint64(iAmt)) - - file := vfsFile.GetOS(ctx, mod, pFile) - n, err := file.ReadAt(buf, iOfst) - if n == int(iAmt) { - return _OK - } - if n == 0 && err != io.EOF { - return uint32(IOERR_READ) - } - for i := range buf[n:] { - buf[n+i] = 0 - } - return uint32(IOERR_SHORT_READ) -} - -func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst int64) uint32 { - buf := memory{mod}.view(zBuf, uint64(iAmt)) - - file := vfsFile.GetOS(ctx, mod, pFile) - _, err := file.WriteAt(buf, iOfst) - if err != nil { - return uint32(IOERR_WRITE) - } - return _OK -} - -func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) uint32 { - file := vfsFile.GetOS(ctx, mod, pFile) - err := file.Truncate(nByte) - if err != nil { - return uint32(IOERR_TRUNCATE) - } - return _OK -} - -func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags _SyncFlag) uint32 { - dataonly := (flags & _SYNC_DATAONLY) != 0 - fullsync := (flags & 0x0f) == _SYNC_FULL - - file := vfsFile.GetOS(ctx, mod, pFile) - err := vfsOS.Sync(file, fullsync, dataonly) - if err != nil { - return uint32(IOERR_FSYNC) - } - if runtime.GOOS != "windows" && vfsFile.GetSyncDir(ctx, mod, pFile) { - vfsFile.SetSyncDir(ctx, mod, pFile, false) - f, err := os.Open(filepath.Dir(file.Name())) - if err != nil { - return _OK - } - defer f.Close() - err = vfsOS.Sync(f, false, false) - if err != nil { - return uint32(IOERR_DIR_FSYNC) - } - } - return _OK -} - -func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 { - file := vfsFile.GetOS(ctx, mod, pFile) - off, err := file.Seek(0, io.SeekEnd) - if err != nil { - return uint32(IOERR_SEEK) - } - - memory{mod}.writeUint64(pSize, uint64(off)) - return _OK -} - -func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) uint32 { - switch op { - case _FCNTL_SIZE_HINT: - return vfsSizeHint(ctx, mod, pFile, pArg) - case _FCNTL_HAS_MOVED: - return vfsFileMoved(ctx, mod, pFile, pArg) - } - return uint32(NOTFOUND) -} - -func vfsSizeHint(ctx context.Context, mod api.Module, pFile, pArg uint32) uint32 { - file := vfsFile.GetOS(ctx, mod, pFile) - size := memory{mod}.readUint64(pArg) - err := vfsOS.Allocate(file, int64(size)) - if err != nil { - return uint32(IOERR_TRUNCATE) - } - return _OK -} - -func vfsFileMoved(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 { - file := vfsFile.GetOS(ctx, mod, pFile) - fi, err := file.Stat() - if err != nil { - return uint32(IOERR_FSTAT) - } - pi, err := os.Stat(file.Name()) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - return uint32(IOERR_FSTAT) - } - var res uint32 - if !os.SameFile(fi, pi) { - res = 1 - } - memory{mod}.writeUint32(pResOut, res) - return _OK -} - -func vfsRegisterFunc1(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ uint32) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, uint32(stack[0]))) - }), - []api.ValueType{api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} - -func vfsRegisterFunc2[T0, T1 ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]))) - }), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} - -func vfsRegisterFunc3[T0, T1, T2 ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))) - }), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} - -func vfsRegisterFunc4[T0, T1, T2, T3 ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))) - }), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} - -func vfsRegisterFunc5[T0, T1, T2, T3, T4 ~uint32](mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ T0, _ T1, _ T2, _ T3, _ T4) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))) - }), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} - -func vfsRegisterFuncRW(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _, _, _ uint32, _ int64) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, uint32(stack[0]), uint32(stack[1]), uint32(stack[2]), int64(stack[3]))) - }), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} - -func vfsRegisterFuncT(mod wazero.HostModuleBuilder, name string, fn func(ctx context.Context, mod api.Module, _ uint32, _ int64) uint32) { - mod.NewFunctionBuilder(). - WithGoModuleFunction(api.GoModuleFunc( - func(ctx context.Context, mod api.Module, stack []uint64) { - stack[0] = uint64(fn(ctx, mod, uint32(stack[0]), int64(stack[1]))) - }), - []api.ValueType{api.ValueTypeI32, api.ValueTypeI64}, []api.ValueType{api.ValueTypeI32}). - Export(name) -} diff --git a/vfs_file.go b/vfs_file.go deleted file mode 100644 index 70f02be..0000000 --- a/vfs_file.go +++ /dev/null @@ -1,91 +0,0 @@ -package sqlite3 - -import ( - "context" - "os" - "time" - - "github.com/tetratelabs/wazero/api" -) - -const ( - // These need to match the offsets asserted in os.c - vfsFileIDOffset = 4 - vfsFileLockOffset = 8 - vfsFileSyncDirOffset = 10 - vfsFileReadOnlyOffset = 11 - vfsFileLockTimeoutOffset = 12 -) - -func (vfsFileMethods) NewID(ctx context.Context, file *os.File) uint32 { - vfs := ctx.Value(vfsKey{}).(*vfsState) - - // Find an empty slot. - for id, ptr := range vfs.files { - if ptr == nil { - vfs.files[id] = file - return uint32(id) - } - } - - // Add a new slot. - vfs.files = append(vfs.files, file) - return uint32(len(vfs.files) - 1) -} - -func (vfsFileMethods) Open(ctx context.Context, mod api.Module, pFile uint32, file *os.File) { - mem := memory{mod} - id := vfsFile.NewID(ctx, file) - mem.writeUint32(pFile+vfsFileIDOffset, id) -} - -func (vfsFileMethods) Close(ctx context.Context, mod api.Module, pFile uint32) error { - mem := memory{mod} - id := mem.readUint32(pFile + vfsFileIDOffset) - vfs := ctx.Value(vfsKey{}).(*vfsState) - file := vfs.files[id] - vfs.files[id] = nil - return file.Close() -} - -func (vfsFileMethods) GetOS(ctx context.Context, mod api.Module, pFile uint32) *os.File { - mem := memory{mod} - id := mem.readUint32(pFile + vfsFileIDOffset) - vfs := ctx.Value(vfsKey{}).(*vfsState) - return vfs.files[id] -} - -func (vfsFileMethods) GetLock(ctx context.Context, mod api.Module, pFile uint32) vfsLockState { - mem := memory{mod} - return vfsLockState(mem.readUint8(pFile + vfsFileLockOffset)) -} - -func (vfsFileMethods) SetLock(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) { - mem := memory{mod} - mem.writeUint8(pFile+vfsFileLockOffset, uint8(lock)) -} - -func (vfsFileMethods) GetLockTimeout(ctx context.Context, mod api.Module, pFile uint32) time.Duration { - mem := memory{mod} - return time.Duration(mem.readUint32(pFile+vfsFileLockTimeoutOffset)) * time.Millisecond -} - -func (vfsFileMethods) GetSyncDir(ctx context.Context, mod api.Module, pFile uint32) bool { - mem := memory{mod} - return mem.readBool8(pFile + vfsFileSyncDirOffset) -} - -func (vfsFileMethods) SetSyncDir(ctx context.Context, mod api.Module, pFile uint32, val bool) { - mem := memory{mod} - mem.writeBool8(pFile+vfsFileSyncDirOffset, val) -} - -func (vfsFileMethods) GetReadOnly(ctx context.Context, mod api.Module, pFile uint32) bool { - mem := memory{mod} - return mem.readBool8(pFile + vfsFileReadOnlyOffset) -} - -func (vfsFileMethods) SetReadOnly(ctx context.Context, mod api.Module, pFile uint32, val bool) { - mem := memory{mod} - mem.writeBool8(pFile+vfsFileReadOnlyOffset, val) -} diff --git a/vfs_os_bsd.go b/vfs_os_bsd.go deleted file mode 100644 index 5a7f96e..0000000 --- a/vfs_os_bsd.go +++ /dev/null @@ -1,56 +0,0 @@ -//go:build freebsd || openbsd || netbsd || dragonfly || (darwin && sqlite3_bsd) - -package sqlite3 - -import ( - "os" - "time" - - "golang.org/x/sys/unix" -) - -func (vfsOSMethods) unlock(file *os.File, start, len int64) xErrorCode { - if start == 0 && len == 0 { - err := unix.Flock(int(file.Fd()), unix.LOCK_UN) - if err != nil { - return IOERR_UNLOCK - } - } - return _OK -} - -func (vfsOSMethods) lock(file *os.File, how int, timeout time.Duration, def xErrorCode) xErrorCode { - var err error - for { - err = unix.Flock(int(file.Fd()), how) - if errno, _ := err.(unix.Errno); errno != unix.EAGAIN { - break - } - if timeout < time.Millisecond { - break - } - timeout -= time.Millisecond - time.Sleep(time.Millisecond) - } - return vfsOS.lockErrorCode(err, def) -} - -func (vfsOSMethods) readLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, unix.LOCK_SH|unix.LOCK_NB, timeout, IOERR_RDLOCK) -} - -func (vfsOSMethods) writeLock(file *os.File, start, len int64, timeout time.Duration) xErrorCode { - return vfsOS.lock(file, unix.LOCK_EX|unix.LOCK_NB, timeout, IOERR_LOCK) -} - -func (vfsOSMethods) checkLock(file *os.File, start, len int64) (bool, xErrorCode) { - lock := unix.Flock_t{ - Type: unix.F_RDLCK, - Start: start, - Len: len, - } - if unix.FcntlFlock(file.Fd(), unix.F_GETLK, &lock) != nil { - return false, IOERR_CHECKRESERVEDLOCK - } - return lock.Type != unix.F_UNLCK, _OK -}