diff --git a/module_test.go b/module_test.go index 9587dbe..b6eb28a 100644 --- a/module_test.go +++ b/module_test.go @@ -76,7 +76,6 @@ func TestConn_newArena(t *testing.T) { defer arena.free() const title = "Lorem ipsum" - ptr := arena.string(title) if ptr == 0 { t.Fatalf("got nullptr") @@ -93,6 +92,19 @@ func TestConn_newArena(t *testing.T) { if got := util.ReadString(m.mod, ptr, math.MaxUint32); got != body { t.Errorf("got %q, want %q", got, body) } + + ptr = arena.bytes(nil) + if ptr != 0 { + t.Errorf("want nullptr") + } + ptr = arena.bytes([]byte(title)) + if ptr == 0 { + t.Fatalf("got nullptr") + } + if got := util.View(m.mod, ptr, uint64(len(title))); string(got) != title { + t.Errorf("got %q, want %q", got, title) + } + arena.free() } diff --git a/sqlite3vfs/memory.go b/sqlite3vfs/memory.go index b6e114c..fcc50b6 100644 --- a/sqlite3vfs/memory.go +++ b/sqlite3vfs/memory.go @@ -43,6 +43,9 @@ func (vfs MemoryVFS) FullPathname(name string) (string, error) { const memSectorSize = 65536 +// A MemoryDB is a [MemoryVFS] database. +// +// A MemoryDB is safe to access concurrently from multiple SQLite connections. type MemoryDB struct { mtx sync.RWMutex size int64 @@ -60,10 +63,12 @@ type memoryFile struct { readOnly bool } +// Close implements the [File] and [io.Closer] interfaces. func (m *memoryFile) Close() error { return m.Unlock(LOCK_NONE) } +// ReadAt implements the [File] and [io.ReaderAt] interfaces. func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { m.mtx.RLock() defer m.mtx.RUnlock() @@ -71,31 +76,40 @@ func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { if off >= m.size { return 0, io.EOF } + base := off / memSectorSize rest := off % memSectorSize have := int64(memSectorSize) if base == int64(len(m.data))-1 { - have = m.size % memSectorSize + have = modRoundUp(m.size, memSectorSize) } - return copy(b, (*m.data[base])[rest:have]), nil + n = copy(b, (*m.data[base])[rest:have]) + if n < len(b) { + // Assume reads are page aligned. + return 0, io.ErrNoProgress + } + return n, nil } +// WriteAt implements the [File] and [io.WriterAt] interfaces. func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { m.mtx.Lock() defer m.mtx.Unlock() base := off / memSectorSize rest := off % memSectorSize - if base >= int64(len(m.data)) { + for base >= int64(len(m.data)) { m.data = append(m.data, new([memSectorSize]byte)) } n = copy((*m.data[base])[rest:], b) - if size := off + int64(n); size > m.size { - m.size = size + if n < len(b) { + // Assume writes are page aligned. + return 0, io.ErrShortWrite } return n, nil } +// Truncate implements the [File] interface. func (m *memoryFile) Truncate(size int64) error { m.mtx.Lock() defer m.mtx.Unlock() @@ -106,31 +120,33 @@ func (m *memoryFile) truncate(size int64) error { if size < m.size { base := size / memSectorSize rest := size % memSectorSize - clear((*m.data[base])[rest:]) + if rest != 0 { + clear((*m.data[base])[rest:]) + } } - sectors := (size + memSectorSize - 1) / memSectorSize + sectors := divRoundUp(size, memSectorSize) for sectors > int64(len(m.data)) { m.data = append(m.data, new([memSectorSize]byte)) } - for sectors < int64(len(m.data)) { - last := int64(len(m.data)) - 1 - m.data[last] = nil - m.data = m.data[:last] - } + clear(m.data[sectors:]) + m.data = m.data[:sectors] m.size = size return nil } +// Sync implements the [File] interface. func (*memoryFile) Sync(flag SyncFlag) error { return nil } +// Size implements the [File] interface. func (m *memoryFile) Size() (int64, error) { m.mtx.RLock() defer m.mtx.RUnlock() return m.size, nil } +// Lock implements the [File] interface. func (m *memoryFile) Lock(lock LockLevel) error { if m.lock >= lock { return nil @@ -185,6 +201,7 @@ func (m *memoryFile) Lock(lock LockLevel) error { return nil } +// Unlock implements the [File] interface. func (m *memoryFile) Unlock(lock LockLevel) error { if m.lock <= lock { return nil @@ -206,6 +223,7 @@ func (m *memoryFile) Unlock(lock LockLevel) error { return nil } +// CheckReservedLock implements the [File] interface. func (m *memoryFile) CheckReservedLock() (bool, error) { if m.lock >= LOCK_RESERVED { return true, nil @@ -215,10 +233,12 @@ func (m *memoryFile) CheckReservedLock() (bool, error) { return m.reserved != nil, nil } +// SectorSize implements the [File] interface. func (*memoryFile) SectorSize() int { return memSectorSize } +// DeviceCharacteristics implements the [File] interface. func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic { return IOCAP_ATOMIC | IOCAP_SEQUENTIAL | @@ -226,6 +246,7 @@ func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic { IOCAP_POWERSAFE_OVERWRITE } +// SizeHint implements the [FileSizeHint] interface. func (m *memoryFile) SizeHint(size int64) error { m.mtx.Lock() defer m.mtx.Unlock() @@ -235,12 +256,22 @@ func (m *memoryFile) SizeHint(size int64) error { return nil } +// LockState implements the [FileLockState] interface. func (m *memoryFile) LockState() LockLevel { return m.lock } -func clear(b []byte) { +func divRoundUp(a, b int64) int64 { + return (a + b - 1) / b +} + +func modRoundUp(a, b int64) int64 { + return b - (b-a%b)%b +} + +func clear[T any](b []T) { + var zero T for i := range b { - b[i] = 0 + b[i] = zero } } diff --git a/sqlite3vfs/reader_test.go b/sqlite3vfs/reader_test.go index 192d6aa..34e856f 100644 --- a/sqlite3vfs/reader_test.go +++ b/sqlite3vfs/reader_test.go @@ -3,6 +3,7 @@ package sqlite3vfs_test import ( "database/sql" "fmt" + "io" "log" "os" "path/filepath" @@ -125,7 +126,23 @@ func TestReaderVFS_Open(t *testing.T) { } func TestNewSizeReaderAt(t *testing.T) { - n, err := sqlite3vfs.NewSizeReaderAt(strings.NewReader("abc")).Size() + f, err := os.Create(filepath.Join(t.TempDir(), "abc.txt")) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + n, err := sqlite3vfs.NewSizeReaderAt(f).Size() + if err != nil { + t.Fatal(err) + } + if n != 0 { + t.Errorf("got %d", n) + } + + reader := strings.NewReader("abc") + + n, err = sqlite3vfs.NewSizeReaderAt(reader).Size() if err != nil { t.Fatal(err) } @@ -133,17 +150,55 @@ func TestNewSizeReaderAt(t *testing.T) { t.Errorf("got %d", n) } - f, err := os.Create(filepath.Join(t.TempDir(), "abc.txt")) + n, err = sqlite3vfs.NewSizeReaderAt(lener{reader, reader.Len()}).Size() if err != nil { t.Fatal(err) } - defer f.Close() - - n, err = sqlite3vfs.NewSizeReaderAt(f).Size() - if err != nil { - t.Fatal(err) - } - if n != 0 { + if n != 3 { t.Errorf("got %d", n) } + + n, err = sqlite3vfs.NewSizeReaderAt(sizer{reader, reader.Size()}).Size() + if err != nil { + t.Fatal(err) + } + if n != 3 { + t.Errorf("got %d", n) + } + + n, err = sqlite3vfs.NewSizeReaderAt(seeker{reader, reader}).Size() + if err != nil { + t.Fatal(err) + } + if n != 3 { + t.Errorf("got %d", n) + } + + _, err = sqlite3vfs.NewSizeReaderAt(readerat{reader}).Size() + if err == nil { + t.Error("want error") + } +} + +type lener struct { + io.ReaderAt + len int +} + +func (l lener) Len() int { return l.len } + +type sizer struct { + io.ReaderAt + size int64 +} + +func (l sizer) Size() (int64, error) { return l.size, nil } + +type seeker struct { + io.ReaderAt + io.Seeker +} + +type readerat struct { + io.ReaderAt } diff --git a/sqlite3vfs/tests/mptest/mptest_test.go b/sqlite3vfs/tests/mptest/mptest_test.go index 1dd088d..59b8298 100644 --- a/sqlite3vfs/tests/mptest/mptest_test.go +++ b/sqlite3vfs/tests/mptest/mptest_test.go @@ -32,6 +32,7 @@ var ( rt wazero.Runtime module wazero.CompiledModule instances atomic.Uint64 + memory = sqlite3vfs.MemoryVFS{} ) func init() { @@ -51,6 +52,8 @@ func init() { if err != nil { panic(err) } + + sqlite3vfs.Register("memvfs", memory) } func config(ctx context.Context) wazero.ModuleConfig { @@ -72,11 +75,28 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 { buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr) buf = buf[:bytes.IndexByte(buf, 0)] + var memvfs, journal, timeout bool args := strings.Split(string(buf), " ") for i := range args { args[i] = strings.Trim(args[i], `"`) + switch args[i] { + case "memvfs": + memvfs = true + case "--timeout": + timeout = true + case "--journalmode": + journal = true + } } args = args[:len(args)-1] + if memvfs { + if !timeout { + args = append(args, "--timeout", "1000") + } + if !journal { + args = append(args, "--journalmode", "memory") + } + } cfg := config(ctx).WithArgs(args...) go func() { @@ -151,6 +171,42 @@ func Test_multiwrite01(t *testing.T) { vfs.Close() } +func Test_config01_memory(t *testing.T) { + memory["test.db"] = new(sqlite3vfs.MemoryDB) + ctx, vfs := sqlite3vfs.NewContext(newContext(t)) + cfg := config(ctx).WithArgs("mptest", "test.db", + "config01.test", + "--vfs", "memvfs", + "--timeout", "1000", + "--journalmode", "memory") + mod, err := rt.InstantiateModule(ctx, module, cfg) + if err != nil { + t.Error(err) + } + mod.Close(ctx) + vfs.Close() +} + +func Test_multiwrite01_memory(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + memory["test.db"] = new(sqlite3vfs.MemoryDB) + ctx, vfs := sqlite3vfs.NewContext(newContext(t)) + cfg := config(ctx).WithArgs("mptest", "test.db", + "multiwrite01.test", + "--vfs", "memvfs", + "--timeout", "1000", + "--journalmode", "memory") + mod, err := rt.InstantiateModule(ctx, module, cfg) + if err != nil { + t.Error(err) + } + mod.Close(ctx) + vfs.Close() +} + func newContext(t *testing.T) context.Context { return context.WithValue(context.Background(), logger{}, &testWriter{T: t}) }