diff --git a/conn.go b/conn.go index e201ae4..0b6b3fa 100644 --- a/conn.go +++ b/conn.go @@ -3,7 +3,6 @@ package sqlite3 import ( "bytes" "context" - "os" "strconv" "github.com/tetratelabs/wazero" @@ -16,7 +15,6 @@ type Conn struct { module api.Module memory api.Memory api sqliteAPI - files []*os.File } func Open(filename string) (conn *Conn, err error) { @@ -40,7 +38,7 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { }() c := newConn(module) - c.ctx = context.WithValue(ctx, connContext{}, c) + c.ctx = context.Background() namePtr := c.newString(filename) connPtr := c.new(ptrSize) defer c.free(namePtr) @@ -68,11 +66,6 @@ func (c *Conn) Close() error { if err := c.error(r[0]); err != nil { return err } - for _, f := range c.files { - if f != nil { - f.Close() - } - } return c.module.Close(c.ctx) } @@ -225,17 +218,4 @@ func getString(memory api.Memory, ptr, maxlen uint32) string { } } -func (c *Conn) getFile(f *os.File) uint32 { - for i := range c.files { - if c.files[i] == nil { - c.files[i] = f - return uint32(i) - } - } - c.files = append(c.files, f) - return uint32(len(c.files) - 1) -} - -type connContext struct{} - const ptrSize = 4 diff --git a/sqlite3/main.c b/sqlite3/main.c index eb2e532..2f6d24c 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -21,6 +21,7 @@ int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut); struct go_file { sqlite3_file base; int id; + int eLock; }; int go_close(sqlite3_file *); @@ -30,12 +31,17 @@ int go_truncate(sqlite3_file *, sqlite3_int64 size); int go_sync(sqlite3_file *, int flags); int go_file_size(sqlite3_file *, sqlite3_int64 *pSize); +int go_lock(sqlite3_file *pFile, int eLock); +int go_unlock(sqlite3_file *pFile, int eLock); +int go_check_reserved_lock(sqlite3_file *pFile, int *pResOut); + static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; } static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; } static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) { *pResOut = 0; return SQLITE_OK; } + static int no_file_control(sqlite3_file *pFile, int op, void *pArg) { return SQLITE_NOTFOUND; } @@ -52,9 +58,9 @@ static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName, .xTruncate = go_truncate, .xSync = go_sync, .xFileSize = go_file_size, - .xLock = no_lock, - .xUnlock = no_unlock, - .xCheckReservedLock = no_check_reserved_lock, + .xLock = go_lock, + .xUnlock = go_unlock, + .xCheckReservedLock = go_check_reserved_lock, .xFileControl = no_file_control, .xSectorSize = no_sector_size, .xDeviceCharacteristics = no_device_characteristics, diff --git a/vfs.go b/vfs.go index 1ef97ea..174023e 100644 --- a/vfs.go +++ b/vfs.go @@ -40,6 +40,9 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) { env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("go_truncate") env.NewFunctionBuilder().WithFunc(vfsSync).Export("go_sync") env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("go_file_size") + env.NewFunctionBuilder().WithFunc(vfsLock).Export("go_lock") + env.NewFunctionBuilder().WithFunc(vfsUnlock).Export("go_unlock") + env.NewFunctionBuilder().WithFunc(vfsCheckReservedLock).Export("go_check_reserved_lock") _, err = env.Instantiate(ctx) return err } @@ -89,6 +92,10 @@ func vfsFullPathname(ctx context.Context, mod api.Module, pVfs, zRelative, nFull return uint32(IOERR) } + // Consider either using [filepath.EvalSymlinks] to canonicalize the path (as the Unix VFS does). + // Or using [os.Readlink] to resolve a symbolic link (as the Unix VFS did). + // This might be buggy on Windows (the Windows VFS doesn't try). + siz := uint32(len(abs) + 1) if siz > nFull { return uint32(IOERR) @@ -166,7 +173,6 @@ func vfsAccess(ctx context.Context, mod api.Module, pVfs, zPath, flags, pResOut func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile, flags, pOutFlags uint32) uint32 { name := getString(mod.Memory(), zName, _MAX_PATHNAME) - c := ctx.Value(connContext{}).(*Conn) var oflags int if OpenFlag(flags)&OPEN_EXCLUSIVE != 0 { @@ -181,14 +187,17 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile, flags, pOu if OpenFlag(flags)&OPEN_READWRITE != 0 { oflags |= os.O_RDWR } - f, err := os.OpenFile(name, oflags, 0600) + file, err := os.OpenFile(name, oflags, 0600) if err != nil { return uint32(CANTOPEN) } - if ok := mod.Memory().WriteUint32Le(pFile+ptrSize, c.getFile(f)); !ok { - panic(rangeErr) + id, err := vfsGetFileID(file) + if err != nil { + return uint32(CANTOPEN) } + vfsSetFileData(mod, pFile, id, _NO_LOCK) + if pOutFlags == 0 { return _OK } @@ -199,37 +208,20 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zName, pFile, flags, pOu } func vfsClose(ctx context.Context, mod api.Module, pFile uint32) uint32 { - id, ok := mod.Memory().ReadUint32Le(pFile + ptrSize) - if !ok { - panic(rangeErr) - } - - c := ctx.Value(connContext{}).(*Conn) - err := c.files[id].Close() - c.files[id] = nil + err := vfsReleaseFile(mod, pFile) if err != nil { return uint32(IOERR_CLOSE) } return _OK } -func vfsFile(ctx context.Context, mod api.Module, pFile uint32) *os.File { - id, ok := mod.Memory().ReadUint32Le(pFile + ptrSize) - if !ok { - panic(rangeErr) - } - - c := ctx.Value(connContext{}).(*Conn) - return c.files[id] -} - func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfst uint64) uint32 { mem, ok := mod.Memory().Read(zBuf, iAmt) if !ok { panic(rangeErr) } - file := vfsFile(ctx, mod, pFile) + file := vfsGetOSFile(mod, pFile) n, err := file.ReadAt(mem, int64(iOfst)) if n == int(iAmt) { return _OK @@ -249,7 +241,7 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOf panic(rangeErr) } - file := vfsFile(ctx, mod, pFile) + file := vfsGetOSFile(mod, pFile) _, err := file.WriteAt(mem, int64(iOfst)) if err != nil { return uint32(IOERR_WRITE) @@ -258,7 +250,7 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOf } func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64) uint32 { - file := vfsFile(ctx, mod, pFile) + file := vfsGetOSFile(mod, pFile) err := file.Truncate(int64(nByte)) if err != nil { return uint32(IOERR_TRUNCATE) @@ -267,7 +259,7 @@ func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte uint64 } func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 { - file := vfsFile(ctx, mod, pFile) + file := vfsGetOSFile(mod, pFile) err := file.Sync() if err != nil { return uint32(IOERR_FSYNC) @@ -276,7 +268,9 @@ func vfsSync(ctx context.Context, mod api.Module, pFile, flags uint32) uint32 { } func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) uint32 { - file := vfsFile(ctx, mod, pFile) + // This uses [file.Seek] because we don't care about the offset for reading/writing. + // But consider using [file.Stat] instead (as other VFSes do). + file := vfsGetOSFile(mod, pFile) off, err := file.Seek(0, io.SeekEnd) if err != nil { return uint32(IOERR_SEEK) diff --git a/vfs_files.go b/vfs_files.go new file mode 100644 index 0000000..73447e6 --- /dev/null +++ b/vfs_files.go @@ -0,0 +1,117 @@ +package sqlite3 + +import ( + "os" + "sync" + + "github.com/tetratelabs/wazero/api" +) + +type vfsOpenFile struct { + file *os.File + info os.FileInfo + nref int + lock int +} + +var ( + vfsMutex sync.Mutex + vfsOpenFiles []*vfsOpenFile +) + +func vfsGetFileID(file *os.File) (uint32, error) { + fi, err := file.Stat() + if err != nil { + return 0, err + } + + vfsMutex.Lock() + defer vfsMutex.Unlock() + + // Reuse an already opened file. + for id, of := range vfsOpenFiles { + if of == nil { + continue + } + if os.SameFile(fi, of.info) { + if err := file.Close(); err != nil { + return 0, err + } + of.nref++ + return uint32(id), nil + } + } + + openFile := vfsOpenFile{ + file: file, + info: fi, + nref: 1, + } + + // Find an empty slot. + for id, of := range vfsOpenFiles { + if of == nil { + vfsOpenFiles[id] = &openFile + return uint32(id), nil + } + } + + // Add a new slot. + id := len(vfsOpenFiles) + vfsOpenFiles = append(vfsOpenFiles, &openFile) + return uint32(id), nil +} + +func vfsReleaseFile(mod api.Module, pFile uint32) error { + id, ok := mod.Memory().ReadUint32Le(pFile + ptrSize) + if !ok { + panic(rangeErr) + } + + vfsMutex.Lock() + defer vfsMutex.Unlock() + + of := vfsOpenFiles[id] + if of.nref--; of.nref > 0 { + return nil + } + err := of.file.Close() + vfsOpenFiles[id] = nil + return err +} + +func vfsGetOSFile(mod api.Module, pFile uint32) *os.File { + id, ok := mod.Memory().ReadUint32Le(pFile + ptrSize) + if !ok { + panic(rangeErr) + } + return vfsOpenFiles[id].file +} + +func vfsGetFileData(mod api.Module, pFile uint32) (id, lock uint32) { + var ok bool + if id, ok = mod.Memory().ReadUint32Le(pFile + ptrSize); !ok { + panic(rangeErr) + } + if lock, ok = mod.Memory().ReadUint32Le(pFile + 2*ptrSize); !ok { + panic(rangeErr) + } + return +} + +func vfsSetFileData(mod api.Module, pFile, id, lock uint32) { + if ok := mod.Memory().WriteUint32Le(pFile+ptrSize, id); !ok { + panic(rangeErr) + } + if ok := mod.Memory().WriteUint32Le(pFile+2*ptrSize, lock); !ok { + panic(rangeErr) + } +} + +const ( + _NO_LOCK = 0 + _SHARED_LOCK = 1 + _RESERVED_LOCK = 2 + _PENDING_LOCK = 3 + _EXCLUSIVE_LOCK = 4 +) diff --git a/vfs_unix.go b/vfs_unix.go new file mode 100644 index 0000000..b7d7468 --- /dev/null +++ b/vfs_unix.go @@ -0,0 +1,22 @@ +//go:build unix + +package sqlite3 + +import ( + "context" + + "github.com/tetratelabs/wazero/api" +) + +func vfsLock(ctx context.Context, mod api.Module, pFile, eLock uint32) uint32 { + return _OK +} + +func vfsUnlock(ctx context.Context, mod api.Module, pFile, eLock uint32) uint32 { + return _OK +} + +func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 { + mod.Memory().WriteUint32Le(pResOut, 0) + return _OK +} diff --git a/vfs_windows.go b/vfs_windows.go new file mode 100644 index 0000000..a979332 --- /dev/null +++ b/vfs_windows.go @@ -0,0 +1,20 @@ +package sqlite3 + +import ( + "context" + + "github.com/tetratelabs/wazero/api" +) + +func vfsLock(ctx context.Context, pFile, eLock uint32) uint32 { + return _OK +} + +func vfsUnlock(ctx context.Context, pFile, eLock uint32) uint32 { + return _OK +} + +func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) uint32 { + mod.Memory().WriteUint32Le(pResOut, 0) + return _OK +}