From 62b79d2ac3891f8aa9fec2b1cc7530914e74a03e Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 19 Apr 2024 18:51:27 +0100 Subject: [PATCH] Shared memory API. --- internal/util/unwrap.go | 8 -------- vfs/adiantum/hbsh.go | 33 ++++++++++++++++++++------------- vfs/api.go | 22 +++++++++++++++++++++- vfs/file.go | 8 -------- vfs/shm.go | 2 ++ vfs/vfs.go | 12 +++++++----- 6 files changed, 50 insertions(+), 35 deletions(-) delete mode 100644 internal/util/unwrap.go diff --git a/internal/util/unwrap.go b/internal/util/unwrap.go deleted file mode 100644 index 1cb09a2..0000000 --- a/internal/util/unwrap.go +++ /dev/null @@ -1,8 +0,0 @@ -package util - -func Unwrap[T any](v T) T { - if u, ok := any(v).(interface{ Unwrap() T }); ok { - return u.Unwrap() - } - return v -} diff --git a/vfs/adiantum/hbsh.go b/vfs/adiantum/hbsh.go index ec8605b..9d96508 100644 --- a/vfs/adiantum/hbsh.go +++ b/vfs/adiantum/hbsh.go @@ -27,18 +27,19 @@ func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values) // Encrypt everything except super journals. if flags&vfs.OPEN_SUPER_JOURNAL == 0 { var key []byte - if t, ok := params["key"]; ok { + if name == "" { + key = h.hbsh.KDF("") // Temporary files get a random key. + } else if t, ok := params["key"]; ok { key = []byte(t[0]) } else if t, ok := params["hexkey"]; ok { key, _ = hex.DecodeString(t[0]) } else if t, ok := params["textkey"]; ok { key = h.hbsh.KDF(t[0]) - } else if name == "" { - key = h.hbsh.KDF("") } if hbsh = h.hbsh.HBSH(key); hbsh == nil { - return nil, flags, sqlite3.NOTADB + // Can't open without a valid key. + return nil, flags, sqlite3.CANTOPEN } } @@ -51,6 +52,7 @@ func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values) file, flags, err = h.Open(name, flags) } if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 { + // Error, or no encryption (super journals, memory files). return file, flags, err } return &hbshFile{File: file, hbsh: hbsh}, flags, err @@ -72,8 +74,8 @@ func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) { min := (off) &^ (blockSize - 1) // round down max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up + // Read one block at a time. for ; min < max; min += blockSize { - // Read full block. m, err := h.File.ReadAt(h.block[:], min) if m != blockSize { return n, err @@ -98,20 +100,24 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) { min := (off) &^ (blockSize - 1) // round down max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up + // Write one block at a time. for ; min < max; min += blockSize { binary.LittleEndian.PutUint64(h.tweak[:], uint64(min)) data := h.block[:] if off > min || len(p[n:]) < blockSize { - // Read full block. + // Partial block write: read-update-write. m, err := h.File.ReadAt(h.block[:], min) if m != blockSize { if err != io.EOF { return n, err } - // Writing past the EOF. - // A partially written block is corrupt, - // and also considered to be past the EOF. + // Writing past the EOF: + // We're either appending an entirely new block, + // or the final block was only partially written. + // A partially written block can't be decripted, + // and is as good as corrupt. + // Either way, zero pad the file to the next block size. clear(data) } @@ -124,7 +130,6 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) { t := copy(data, p[n:]) h.hbsh.Encrypt(h.block[:], h.tweak[:]) - // Write full block. m, err := h.File.WriteAt(h.block[:], min) if m != blockSize { return n, err @@ -155,9 +160,11 @@ func (h *hbshFile) DeviceCharacteristics() vfs.DeviceCharacteristic { vfs.IOCAP_BATCH_ATOMIC) } -// This is needed for shared memory. -func (h *hbshFile) Unwrap() vfs.File { - return h.File +func (h *hbshFile) SharedMemory() vfs.SharedMemory { + if shm, ok := h.File.(vfs.FileSharedMemory); ok { + return shm.SharedMemory() + } + return nil } // Wrap optional methods. diff --git a/vfs/api.go b/vfs/api.go index e0484e7..f8e7195 100644 --- a/vfs/api.go +++ b/vfs/api.go @@ -1,7 +1,12 @@ // Package vfs wraps the C SQLite VFS API. package vfs -import "net/url" +import ( + "context" + "net/url" + + "github.com/tetratelabs/wazero/api" +) // A VFS defines the interface between the SQLite core and the underlying operating system. // @@ -129,3 +134,18 @@ type FileBatchAtomicWrite interface { CommitAtomicWrite() error RollbackAtomicWrite() error } + +// FileSharedMemory extends File to possibly implement shared memory. +// It's OK for SharedMemory to return nil. +type FileSharedMemory interface { + File + SharedMemory() SharedMemory +} + +// SharedMemory is a shared memory implementation. +// This cannot be externally implemented. +type SharedMemory interface { + shmMap(context.Context, api.Module, int32, int32, bool) (uint32, error) + shmLock(int32, int32, _ShmFlag) error + shmUnmap(bool) +} diff --git a/vfs/file.go b/vfs/file.go index af804d4..4f2aa39 100644 --- a/vfs/file.go +++ b/vfs/file.go @@ -1,7 +1,6 @@ package vfs import ( - "context" "errors" "io" "io/fs" @@ -12,7 +11,6 @@ import ( "syscall" "github.com/ncruces/go-sqlite3/util/osutil" - "github.com/tetratelabs/wazero/api" ) type vfsOS struct{} @@ -215,9 +213,3 @@ func (f *vfsFile) PowersafeOverwrite() bool { return f.psow } func (f *vfsFile) PersistentWAL() bool { return f.keepWAL } func (f *vfsFile) SetPowersafeOverwrite(psow bool) { f.psow = psow } func (f *vfsFile) SetPersistentWAL(keepWAL bool) { f.keepWAL = keepWAL } - -type fileShm interface { - shmMap(context.Context, api.Module, int32, int32, bool) (uint32, error) - shmLock(int32, int32, _ShmFlag) error - shmUnmap(bool) -} diff --git a/vfs/shm.go b/vfs/shm.go index 381498d..0b3c655 100644 --- a/vfs/shm.go +++ b/vfs/shm.go @@ -33,6 +33,8 @@ const ( _SHM_DMS = _SHM_BASE + _SHM_NLOCK ) +func (f *vfsFile) SharedMemory() SharedMemory { return f } + func (f *vfsFile) shmMap(ctx context.Context, mod api.Module, id, size int32, extend bool) (uint32, error) { // Ensure size is a multiple of the OS page size. if int(size)&(unix.Getpagesize()-1) != 0 { diff --git a/vfs/vfs.go b/vfs/vfs.go index 9f5e7b8..33a7ee1 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -171,8 +171,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla util.WriteUint32(mod, pOutFlags, uint32(flags)) } if pOutVFS != 0 && util.CanMapFiles(ctx) { - if _, ok := util.Unwrap(file).(fileShm); ok { - util.WriteUint32(mod, pOutVFS, 1) + if f, ok := file.(FileSharedMemory); ok { + if f.SharedMemory() != nil { + util.WriteUint32(mod, pOutVFS, 1) + } } } vfsFileRegister(ctx, mod, pFile, file) @@ -366,7 +368,7 @@ func vfsShmBarrier(ctx context.Context, mod api.Module, pFile uint32) { } func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szRegion int32, bExtend, pp uint32) _ErrorCode { - file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm) + file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() p, err := file.shmMap(ctx, mod, iRegion, szRegion, bExtend != 0) if err != nil { return vfsErrorCode(err, _IOERR_SHMMAP) @@ -376,13 +378,13 @@ func vfsShmMap(ctx context.Context, mod api.Module, pFile uint32, iRegion, szReg } func vfsShmLock(ctx context.Context, mod api.Module, pFile uint32, offset, n int32, flags _ShmFlag) _ErrorCode { - file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm) + file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() err := file.shmLock(offset, n, flags) return vfsErrorCode(err, _IOERR_SHMLOCK) } func vfsShmUnmap(ctx context.Context, mod api.Module, pFile, bDelete uint32) _ErrorCode { - file := util.Unwrap(vfsFileGet(ctx, mod, pFile)).(fileShm) + file := vfsFileGet(ctx, mod, pFile).(FileSharedMemory).SharedMemory() file.shmUnmap(bDelete != 0) return _OK }