diff --git a/vfs.go b/vfs.go index ad5de95..fa24827 100644 --- a/vfs.go +++ b/vfs.go @@ -57,9 +57,36 @@ func vfsNewEnvModuleBuilder(r wazero.Runtime) wazero.HostModuleBuilder { return env } -type vfsOSMethods bool +// Poor man's namespaces. +const ( + vfsOS vfsOSMethods = false + vfsFile vfsFileMethods = false +) -const vfsOS vfsOSMethods = false +type ( + vfsOSMethods bool + vfsFileMethods bool +) + +type vfsKey struct{} +type vfsState struct { + files []*os.File +} + +func vfsContext(ctx context.Context) (context.Context, io.Closer) { + vfs := &vfsState{} + return context.WithValue(ctx, vfsKey{}, vfs), vfs +} + +func (vfs *vfsState) Close() error { + for _, f := range vfs.files { + if f != nil { + f.Close() + } + } + vfs.files = nil + return nil +} func vfsExit(ctx context.Context, mod api.Module, exitCode uint32) { // Ensure other callers see the exit code. @@ -232,7 +259,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla vfsOS.DeleteOnClose(file) } - vfsFileOpen(ctx, mod, pFile, file) + vfsFile.Open(ctx, mod, pFile, file) if pOutFlags != 0 { memory{mod}.writeUint32(pOutFlags, uint32(flags)) @@ -241,7 +268,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile uint32, fla } func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 { - err := vfsFileClose(ctx, mod, pFile) + err := vfsFile.Close(ctx, mod, pFile) if err != nil { return uint32(IOERR_CLOSE) } @@ -251,7 +278,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 { func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { buf := memory{mod}.view(zBuf, uint64(iAmt)) - file := vfsFileGet(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) n, err := file.ReadAt(buf, int64(iOfst)) if n == int(iAmt) { return _OK @@ -268,7 +295,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { buf := memory{mod}.view(zBuf, uint64(iAmt)) - file := vfsFileGet(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) _, err := file.WriteAt(buf, int64(iOfst)) if err != nil { return uint32(IOERR_WRITE) @@ -277,7 +304,7 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOf } func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 { - file := vfsFileGet(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) err := file.Truncate(int64(nByte)) if err != nil { return uint32(IOERR_TRUNCATE) @@ -286,7 +313,7 @@ func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64 } func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 { - file := vfsFileGet(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) err := file.Sync() if err != nil { return uint32(IOERR_FSYNC) @@ -298,7 +325,7 @@ func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint3 // This uses [os.File.Seek] because we don't care about the offset for reading/writing. // But consider using [os.File.Stat] instead (as other VFSes do). - file := vfsFileGet(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) off, err := file.Seek(0, io.SeekEnd) if err != nil { return uint32(IOERR_SEEK) diff --git a/vfs_files.go b/vfs_files.go index f5b6d27..6503510 100644 --- a/vfs_files.go +++ b/vfs_files.go @@ -2,34 +2,12 @@ package sqlite3 import ( "context" - "io" "os" "github.com/tetratelabs/wazero/api" ) -type vfsKey struct{} - -type vfsState struct { - files []*os.File -} - -func (vfs *vfsState) Close() error { - for _, f := range vfs.files { - if f != nil { - f.Close() - } - } - vfs.files = nil - return nil -} - -func vfsContext(ctx context.Context) (context.Context, io.Closer) { - vfs := &vfsState{} - return context.WithValue(ctx, vfsKey{}, vfs), vfs -} - -func vfsFileNewID(ctx context.Context, file *os.File) uint32 { +func (vfsFileMethods) NewID(ctx context.Context, file *os.File) uint32 { vfs := ctx.Value(vfsKey{}).(*vfsState) // Find an empty slot. @@ -45,14 +23,14 @@ func vfsFileNewID(ctx context.Context, file *os.File) uint32 { return uint32(len(vfs.files) - 1) } -func vfsFileOpen(ctx context.Context, mod api.Module, pFile uint32, file *os.File) { +func (vfsFileMethods) Open(ctx context.Context, mod api.Module, pFile uint32, file *os.File) { mem := memory{mod} - id := vfsFileNewID(ctx, file) + id := vfsFile.NewID(ctx, file) mem.writeUint32(pFile+ptrlen, id) mem.writeUint32(pFile+2*ptrlen, _NO_LOCK) } -func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error { +func (vfsFileMethods) Close(ctx context.Context, mod api.Module, pFile uint32) error { mem := memory{mod} id := mem.readUint32(pFile + ptrlen) vfs := ctx.Value(vfsKey{}).(*vfsState) @@ -61,19 +39,19 @@ func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error { return file.Close() } -func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) *os.File { +func (vfsFileMethods) GetOS(ctx context.Context, mod api.Module, pFile uint32) *os.File { mem := memory{mod} id := mem.readUint32(pFile + ptrlen) vfs := ctx.Value(vfsKey{}).(*vfsState) return vfs.files[id] } -func vfsFileLockState(ctx context.Context, mod api.Module, pFile uint32) vfsLockState { +func (vfsFileMethods) GetLock(ctx context.Context, mod api.Module, pFile uint32) vfsLockState { mem := memory{mod} return vfsLockState(mem.readUint32(pFile + 2*ptrlen)) } -func vfsFileSetLockState(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) { +func (vfsFileMethods) SetLock(ctx context.Context, mod api.Module, pFile uint32, lock vfsLockState) { mem := memory{mod} mem.writeUint32(pFile+2*ptrlen, uint32(lock)) } diff --git a/vfs_lock.go b/vfs_lock.go index 69d2cfe..7a66a21 100644 --- a/vfs_lock.go +++ b/vfs_lock.go @@ -61,8 +61,8 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta panic(assertErr()) } - file := vfsFileGet(ctx, mod, pFile) - cLock := vfsFileLockState(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) + cLock := vfsFile.GetLock(ctx, mod, pFile) switch { case cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK: @@ -94,7 +94,7 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta if rc := vfsOS.GetSharedLock(file); rc != _OK { return uint32(rc) } - vfsFileSetLockState(ctx, mod, pFile, _SHARED_LOCK) + vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK) return _OK case _RESERVED_LOCK: @@ -105,7 +105,7 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta if rc := vfsOS.GetReservedLock(file); rc != _OK { return uint32(rc) } - vfsFileSetLockState(ctx, mod, pFile, _RESERVED_LOCK) + vfsFile.SetLock(ctx, mod, pFile, _RESERVED_LOCK) return _OK case _EXCLUSIVE_LOCK: @@ -118,12 +118,12 @@ func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockSta if rc := vfsOS.GetPendingLock(file); rc != _OK { return uint32(rc) } - vfsFileSetLockState(ctx, mod, pFile, _PENDING_LOCK) + vfsFile.SetLock(ctx, mod, pFile, _PENDING_LOCK) } if rc := vfsOS.GetExclusiveLock(file); rc != _OK { return uint32(rc) } - vfsFileSetLockState(ctx, mod, pFile, _EXCLUSIVE_LOCK) + vfsFile.SetLock(ctx, mod, pFile, _EXCLUSIVE_LOCK) return _OK default: @@ -137,8 +137,8 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS panic(assertErr()) } - file := vfsFileGet(ctx, mod, pFile) - cLock := vfsFileLockState(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) + cLock := vfsFile.GetLock(ctx, mod, pFile) // Connection state check. if cLock < _NO_LOCK || cLock > _EXCLUSIVE_LOCK { @@ -155,12 +155,12 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS if rc := vfsOS.DowngradeLock(file, cLock); rc != _OK { return uint32(rc) } - vfsFileSetLockState(ctx, mod, pFile, _SHARED_LOCK) + vfsFile.SetLock(ctx, mod, pFile, _SHARED_LOCK) return _OK case _NO_LOCK: rc := vfsOS.ReleaseLock(file, cLock) - vfsFileSetLockState(ctx, mod, pFile, _NO_LOCK) + vfsFile.SetLock(ctx, mod, pFile, _NO_LOCK) return uint32(rc) default: @@ -169,13 +169,13 @@ func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock vfsLockS } func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 { - cLock := vfsFileLockState(ctx, mod, pFile) + cLock := vfsFile.GetLock(ctx, mod, pFile) if cLock > _SHARED_LOCK { panic(assertErr()) } - file := vfsFileGet(ctx, mod, pFile) + file := vfsFile.GetOS(ctx, mod, pFile) locked, rc := vfsOS.CheckReservedLock(file) var res uint32 diff --git a/vfs_lock_test.go b/vfs_lock_test.go index 7914991..88d463f 100644 --- a/vfs_lock_test.go +++ b/vfs_lock_test.go @@ -41,8 +41,8 @@ func Test_vfsLock(t *testing.T) { ctx, vfs := vfsContext(context.TODO()) defer vfs.Close() - vfsFileOpen(ctx, mem.mod, pFile1, file1) - vfsFileOpen(ctx, mem.mod, pFile2, file2) + vfsFile.Open(ctx, mem.mod, pFile1, file1) + vfsFile.Open(ctx, mem.mod, pFile2, file2) rc := vfsCheckReservedLock(ctx, mem.mod, pFile1, pOutput) if rc != _OK {