Tests, fixes, docs.

This commit is contained in:
Nuno Cruces
2023-05-29 16:50:48 +01:00
parent f1c46db512
commit 8b2e96dedc
4 changed files with 178 additions and 24 deletions

View File

@@ -43,6 +43,9 @@ func (vfs MemoryVFS) FullPathname(name string) (string, error) {
const memSectorSize = 65536
// A MemoryDB is a [MemoryVFS] database.
//
// A MemoryDB is safe to access concurrently from multiple SQLite connections.
type MemoryDB struct {
mtx sync.RWMutex
size int64
@@ -60,10 +63,12 @@ type memoryFile struct {
readOnly bool
}
// Close implements the [File] and [io.Closer] interfaces.
func (m *memoryFile) Close() error {
return m.Unlock(LOCK_NONE)
}
// ReadAt implements the [File] and [io.ReaderAt] interfaces.
func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
@@ -71,31 +76,40 @@ func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) {
if off >= m.size {
return 0, io.EOF
}
base := off / memSectorSize
rest := off % memSectorSize
have := int64(memSectorSize)
if base == int64(len(m.data))-1 {
have = m.size % memSectorSize
have = modRoundUp(m.size, memSectorSize)
}
return copy(b, (*m.data[base])[rest:have]), nil
n = copy(b, (*m.data[base])[rest:have])
if n < len(b) {
// Assume reads are page aligned.
return 0, io.ErrNoProgress
}
return n, nil
}
// WriteAt implements the [File] and [io.WriterAt] interfaces.
func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) {
m.mtx.Lock()
defer m.mtx.Unlock()
base := off / memSectorSize
rest := off % memSectorSize
if base >= int64(len(m.data)) {
for base >= int64(len(m.data)) {
m.data = append(m.data, new([memSectorSize]byte))
}
n = copy((*m.data[base])[rest:], b)
if size := off + int64(n); size > m.size {
m.size = size
if n < len(b) {
// Assume writes are page aligned.
return 0, io.ErrShortWrite
}
return n, nil
}
// Truncate implements the [File] interface.
func (m *memoryFile) Truncate(size int64) error {
m.mtx.Lock()
defer m.mtx.Unlock()
@@ -106,31 +120,33 @@ func (m *memoryFile) truncate(size int64) error {
if size < m.size {
base := size / memSectorSize
rest := size % memSectorSize
clear((*m.data[base])[rest:])
if rest != 0 {
clear((*m.data[base])[rest:])
}
}
sectors := (size + memSectorSize - 1) / memSectorSize
sectors := divRoundUp(size, memSectorSize)
for sectors > int64(len(m.data)) {
m.data = append(m.data, new([memSectorSize]byte))
}
for sectors < int64(len(m.data)) {
last := int64(len(m.data)) - 1
m.data[last] = nil
m.data = m.data[:last]
}
clear(m.data[sectors:])
m.data = m.data[:sectors]
m.size = size
return nil
}
// Sync implements the [File] interface.
func (*memoryFile) Sync(flag SyncFlag) error {
return nil
}
// Size implements the [File] interface.
func (m *memoryFile) Size() (int64, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
return m.size, nil
}
// Lock implements the [File] interface.
func (m *memoryFile) Lock(lock LockLevel) error {
if m.lock >= lock {
return nil
@@ -185,6 +201,7 @@ func (m *memoryFile) Lock(lock LockLevel) error {
return nil
}
// Unlock implements the [File] interface.
func (m *memoryFile) Unlock(lock LockLevel) error {
if m.lock <= lock {
return nil
@@ -206,6 +223,7 @@ func (m *memoryFile) Unlock(lock LockLevel) error {
return nil
}
// CheckReservedLock implements the [File] interface.
func (m *memoryFile) CheckReservedLock() (bool, error) {
if m.lock >= LOCK_RESERVED {
return true, nil
@@ -215,10 +233,12 @@ func (m *memoryFile) CheckReservedLock() (bool, error) {
return m.reserved != nil, nil
}
// SectorSize implements the [File] interface.
func (*memoryFile) SectorSize() int {
return memSectorSize
}
// DeviceCharacteristics implements the [File] interface.
func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic {
return IOCAP_ATOMIC |
IOCAP_SEQUENTIAL |
@@ -226,6 +246,7 @@ func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic {
IOCAP_POWERSAFE_OVERWRITE
}
// SizeHint implements the [FileSizeHint] interface.
func (m *memoryFile) SizeHint(size int64) error {
m.mtx.Lock()
defer m.mtx.Unlock()
@@ -235,12 +256,22 @@ func (m *memoryFile) SizeHint(size int64) error {
return nil
}
// LockState implements the [FileLockState] interface.
func (m *memoryFile) LockState() LockLevel {
return m.lock
}
func clear(b []byte) {
func divRoundUp(a, b int64) int64 {
return (a + b - 1) / b
}
func modRoundUp(a, b int64) int64 {
return b - (b-a%b)%b
}
func clear[T any](b []T) {
var zero T
for i := range b {
b[i] = 0
b[i] = zero
}
}

View File

@@ -3,6 +3,7 @@ package sqlite3vfs_test
import (
"database/sql"
"fmt"
"io"
"log"
"os"
"path/filepath"
@@ -125,7 +126,23 @@ func TestReaderVFS_Open(t *testing.T) {
}
func TestNewSizeReaderAt(t *testing.T) {
n, err := sqlite3vfs.NewSizeReaderAt(strings.NewReader("abc")).Size()
f, err := os.Create(filepath.Join(t.TempDir(), "abc.txt"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
n, err := sqlite3vfs.NewSizeReaderAt(f).Size()
if err != nil {
t.Fatal(err)
}
if n != 0 {
t.Errorf("got %d", n)
}
reader := strings.NewReader("abc")
n, err = sqlite3vfs.NewSizeReaderAt(reader).Size()
if err != nil {
t.Fatal(err)
}
@@ -133,17 +150,55 @@ func TestNewSizeReaderAt(t *testing.T) {
t.Errorf("got %d", n)
}
f, err := os.Create(filepath.Join(t.TempDir(), "abc.txt"))
n, err = sqlite3vfs.NewSizeReaderAt(lener{reader, reader.Len()}).Size()
if err != nil {
t.Fatal(err)
}
defer f.Close()
n, err = sqlite3vfs.NewSizeReaderAt(f).Size()
if err != nil {
t.Fatal(err)
}
if n != 0 {
if n != 3 {
t.Errorf("got %d", n)
}
n, err = sqlite3vfs.NewSizeReaderAt(sizer{reader, reader.Size()}).Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
n, err = sqlite3vfs.NewSizeReaderAt(seeker{reader, reader}).Size()
if err != nil {
t.Fatal(err)
}
if n != 3 {
t.Errorf("got %d", n)
}
_, err = sqlite3vfs.NewSizeReaderAt(readerat{reader}).Size()
if err == nil {
t.Error("want error")
}
}
type lener struct {
io.ReaderAt
len int
}
func (l lener) Len() int { return l.len }
type sizer struct {
io.ReaderAt
size int64
}
func (l sizer) Size() (int64, error) { return l.size, nil }
type seeker struct {
io.ReaderAt
io.Seeker
}
type readerat struct {
io.ReaderAt
}

View File

@@ -32,6 +32,7 @@ var (
rt wazero.Runtime
module wazero.CompiledModule
instances atomic.Uint64
memory = sqlite3vfs.MemoryVFS{}
)
func init() {
@@ -51,6 +52,8 @@ func init() {
if err != nil {
panic(err)
}
sqlite3vfs.Register("memvfs", memory)
}
func config(ctx context.Context) wazero.ModuleConfig {
@@ -72,11 +75,28 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr)
buf = buf[:bytes.IndexByte(buf, 0)]
var memvfs, journal, timeout bool
args := strings.Split(string(buf), " ")
for i := range args {
args[i] = strings.Trim(args[i], `"`)
switch args[i] {
case "memvfs":
memvfs = true
case "--timeout":
timeout = true
case "--journalmode":
journal = true
}
}
args = args[:len(args)-1]
if memvfs {
if !timeout {
args = append(args, "--timeout", "1000")
}
if !journal {
args = append(args, "--journalmode", "memory")
}
}
cfg := config(ctx).WithArgs(args...)
go func() {
@@ -151,6 +171,42 @@ func Test_multiwrite01(t *testing.T) {
vfs.Close()
}
func Test_config01_memory(t *testing.T) {
memory["test.db"] = new(sqlite3vfs.MemoryDB)
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "test.db",
"config01.test",
"--vfs", "memvfs",
"--timeout", "1000",
"--journalmode", "memory")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func Test_multiwrite01_memory(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
memory["test.db"] = new(sqlite3vfs.MemoryDB)
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
cfg := config(ctx).WithArgs("mptest", "test.db",
"multiwrite01.test",
"--vfs", "memvfs",
"--timeout", "1000",
"--journalmode", "memory")
mod, err := rt.InstantiateModule(ctx, module, cfg)
if err != nil {
t.Error(err)
}
mod.Close(ctx)
vfs.Close()
}
func newContext(t *testing.T) context.Context {
return context.WithValue(context.Background(), logger{}, &testWriter{T: t})
}