diff --git a/conn.go b/conn.go index d04a518..df5816e 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io/fs" + "os" "path/filepath" "strconv" @@ -18,6 +19,7 @@ type Conn struct { module api.Module memory api.Memory api sqliteAPI + files []*os.File } func Open(name string) (conn *Conn, err error) { @@ -82,6 +84,11 @@ func (c *Conn) Close() error { if err := c.error(r[0]); err != nil { return err } + for _, f := range c.files { + if f != nil { + f.Close() + } + } return c.module.Close(c.ctx) } diff --git a/sqlite3/main.c b/sqlite3/main.c index 98a50e0..272081a 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -14,16 +14,65 @@ int go_current_time_64(sqlite3_vfs *, sqlite3_int64 *); int go_open(sqlite3_vfs *, sqlite3_filename zName, sqlite3_file *, int flags, int *pOutFlags); +int go_delete(sqlite3_vfs *, const char *zName, int syncDir); +int go_access(sqlite3_vfs *, const char *zName, int flags, int *pResOut); int go_full_pathname(sqlite3_vfs *, const char *zName, int nOut, char *zOut); +struct go_file { + sqlite3_file base; + int fd; +}; + +int go_close(sqlite3_file *); +int go_read(sqlite3_file *, void *, int iAmt, sqlite3_int64 iOfst); +int go_write(sqlite3_file *, const void *, int iAmt, sqlite3_int64 iOfst); +int go_truncate(sqlite3_file *, sqlite3_int64 size); +int go_sync(sqlite3_file *, int flags); +int go_file_size(sqlite3_file *, sqlite3_int64 *pSize); + +static int no_lock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; } +static int no_unlock(sqlite3_file *pFile, int eLock) { return SQLITE_OK; } +static int no_check_reserved_lock(sqlite3_file *pFile, int *pResOut) { + *pResOut = 0; + return SQLITE_OK; +} +static int no_file_control(sqlite3_file *pFile, int op, void *pArg) { + return SQLITE_NOTFOUND; +} +static int no_sector_size(sqlite3_file *pFile) { return 0; } +static int no_device_characteristics(sqlite3_file *pFile) { return 0; } + +static int go_open_c(sqlite3_vfs *vfs, sqlite3_filename zName, + sqlite3_file *file, int flags, int *pOutFlags) { + static const sqlite3_io_methods go_io = { + .iVersion = 1, + .xClose = go_close, + .xRead = go_read, + .xWrite = go_write, + .xTruncate = go_truncate, + .xSync = go_sync, + .xFileSize = go_file_size, + .xLock = no_lock, + .xUnlock = no_unlock, + .xCheckReservedLock = no_check_reserved_lock, + .xFileControl = no_file_control, + .xSectorSize = no_sector_size, + .xDeviceCharacteristics = no_device_characteristics, + }; + file->pMethods = &go_io; + return go_open(vfs, zName, file, flags, pOutFlags); +} + int sqlite3_os_init() { static sqlite3_vfs go_vfs = { .iVersion = 2, - .szOsFile = sizeof(sqlite3_file), + .szOsFile = sizeof(struct go_file), .mxPathname = 512, .zName = "go", - .xOpen = go_open, + .xOpen = go_open_c, + .xDelete = go_delete, + .xAccess = go_access, .xFullPathname = go_full_pathname, .xRandomness = go_randomness, diff --git a/vfs.go b/vfs.go index 428f284..95736fb 100644 --- a/vfs.go +++ b/vfs.go @@ -2,8 +2,8 @@ package sqlite3 import ( "context" - "log" "math/rand" + "os" "path/filepath" "time" @@ -22,12 +22,20 @@ func vfsInstantiate(ctx context.Context, r wazero.Runtime) (err error) { } env := r.NewHostModuleBuilder("env") - env.NewFunctionBuilder().WithFunc(vfsOpen).Export("go_open") - env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("go_full_pathname") env.NewFunctionBuilder().WithFunc(vfsRandomness).Export("go_randomness") env.NewFunctionBuilder().WithFunc(vfsSleep).Export("go_sleep") env.NewFunctionBuilder().WithFunc(vfsCurrentTime).Export("go_current_time") env.NewFunctionBuilder().WithFunc(vfsCurrentTime64).Export("go_current_time_64") + env.NewFunctionBuilder().WithFunc(vfsFullPathname).Export("go_full_pathname") + env.NewFunctionBuilder().WithFunc(vfsDelete).Export("go_delete") + env.NewFunctionBuilder().WithFunc(vfsAccess).Export("go_access") + env.NewFunctionBuilder().WithFunc(vfsOpen).Export("go_open") + env.NewFunctionBuilder().WithFunc(vfsClose).Export("go_close") + env.NewFunctionBuilder().WithFunc(vfsRead).Export("go_read") + env.NewFunctionBuilder().WithFunc(vfsWrite).Export("go_write") + env.NewFunctionBuilder().WithFunc(vfsTruncate).Export("go_truncate") + env.NewFunctionBuilder().WithFunc(vfsSync).Export("go_sync") + env.NewFunctionBuilder().WithFunc(vfsFileSize).Export("go_file_size") _, err = env.Instantiate(ctx) return err } @@ -72,15 +80,6 @@ func vfsCurrentTime64(ctx context.Context, mod api.Module, vfs, out uint32) uint return _OK } -func vfsOpen(ctx context.Context, mod api.Module, vfs, zName, file, flags, pOutFlags uint32) uint32 { - name := getString(mod.Memory(), zName, _MAX_PATHNAME) - c, ok := ctx.Value(connContext{}).(*Conn) - if ok && mod == c.module { - log.Println("vfsOpen", name) - } - return uint32(IOERR) -} - func vfsFullPathname(ctx context.Context, mod api.Module, vfs, zName, nOut, zOut uint32) uint32 { name := getString(mod.Memory(), zName, _MAX_PATHNAME) s, err := filepath.Abs(name) @@ -101,3 +100,72 @@ func vfsFullPathname(ctx context.Context, mod api.Module, vfs, zName, nOut, zOut copy(mem, s) return _OK } + +func vfsDelete(vfs, zName, syncDir uint32) uint32 { panic("vfsDelete") } + +func vfsAccess(vfs, zName, flags, pResOut uint32) uint32 { panic("vfsAccess") } + +func vfsOpen(ctx context.Context, mod api.Module, vfs, zName, file, flags, pOutFlags uint32) uint32 { + name := getString(mod.Memory(), zName, _MAX_PATHNAME) + c := ctx.Value(connContext{}).(*Conn) + + var oflags int + if OpenFlag(flags)&OPEN_EXCLUSIVE != 0 { + oflags |= os.O_EXCL + } + if OpenFlag(flags)&OPEN_CREATE != 0 { + oflags |= os.O_CREATE + } + if OpenFlag(flags)&OPEN_READONLY != 0 { + oflags |= os.O_RDONLY + } + if OpenFlag(flags)&OPEN_READWRITE != 0 { + oflags |= os.O_RDWR + } + f, err := os.OpenFile(name, oflags, 0600) + if err != nil { + return uint32(CANTOPEN) + } + + var id int + for i := range c.files { + if c.files[i] == nil { + id = i + c.files[i] = f + goto found + } + } + id = len(c.files) + c.files = append(c.files, f) +found: + + mod.Memory().WriteUint32Le(file+ptrSize, uint32(id)) + return _OK +} + +func vfsClose(ctx context.Context, mod api.Module, file uint32) uint32 { + id, ok := mod.Memory().ReadUint32Le(file + ptrSize) + if !ok { + panic("sqlite: out-of-range") + } + + c := ctx.Value(connContext{}).(*Conn) + err := c.files[id].Close() + c.files[id] = nil + if err != nil { + return uint32(IOERR) + } + return _OK +} + +func vfsRead(file, buf, iAmt uint32, iOfst uint64) uint32 { + return uint32(IOERR) +} + +func vfsWrite(file, buf, iAmt uint32, iOfst uint64) uint32 { panic("vfsWrite") } + +func vfsTruncate(file uint32, size uint64) uint32 { panic("vfsTruncate") } + +func vfsSync(file, flags uint32) uint32 { panic("vfsSync") } + +func vfsFileSize(file, pSize uint32) uint32 { panic("vfsFileSize") }