From 7ca9d7942453944ff0c88a3a894cce9a0eb92708 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 26 May 2023 04:59:54 +0100 Subject: [PATCH] MemoryVFS. --- sqlite3vfs/memory.go | 165 ++++++++++++++++++++++++++++++++ sqlite3vfs/reader.go | 6 +- sqlite3vfs/vfs.go | 4 +- tests/db_test.go | 8 ++ tests/parallel/parallel_test.go | 59 +++++++++--- 5 files changed, 221 insertions(+), 21 deletions(-) create mode 100644 sqlite3vfs/memory.go diff --git a/sqlite3vfs/memory.go b/sqlite3vfs/memory.go new file mode 100644 index 0000000..f508e4e --- /dev/null +++ b/sqlite3vfs/memory.go @@ -0,0 +1,165 @@ +package sqlite3vfs + +import ( + "io" + "sync" +) + +// A MemoryVFS is a [VFS] for memory databases. +type MemoryVFS map[string]*MemoryDB + +var _ VFS = MemoryVFS{} + +// Open implements the [VFS] interface. +func (vfs MemoryVFS) Open(name string, flags OpenFlag) (File, OpenFlag, error) { + if flags&OPEN_MAIN_DB == 0 { + return nil, flags, _CANTOPEN + } + if db, ok := vfs[name]; ok { + return &memoryFile{ + MemoryDB: db, + readOnly: flags&OPEN_READONLY != 0, + }, flags, nil + } + return nil, flags, _CANTOPEN +} + +// Delete implements the [VFS] interface. +func (vfs MemoryVFS) Delete(name string, dirSync bool) error { + return _IOERR_DELETE +} + +// Access implements the [VFS] interface. +func (vfs MemoryVFS) Access(name string, flag AccessFlag) (bool, error) { + return false, nil +} + +// FullPathname implements the [VFS] interface. +func (vfs MemoryVFS) FullPathname(name string) (string, error) { + return name, nil +} + +const memSectorSize = 65536 + +type MemoryDB struct { + size int64 + data []*[memSectorSize]byte + mtx sync.Mutex +} + +type memoryFile struct { + *MemoryDB + locked bool + readOnly bool +} + +func (m *memoryFile) Close() error { + return m.Unlock(LOCK_NONE) +} + +func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { + if off >= m.size { + return 0, io.EOF + } + base := off / memSectorSize + rest := off % memSectorSize + have := int64(memSectorSize) + if base == int64(len(m.data))-1 { + have = m.size % memSectorSize + } + return copy(b, (*m.data[base])[rest:have]), nil +} + +func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { + base := off / memSectorSize + rest := off % memSectorSize + if base >= int64(len(m.data)) { + m.data = append(m.data, new([memSectorSize]byte)) + } + n = copy((*m.data[base])[rest:], b) + if size := off + int64(n); size > m.size { + m.size = size + } + return n, nil +} + +func (m *memoryFile) Truncate(size int64) error { + if size < m.size { + base := size / memSectorSize + rest := size % memSectorSize + clear((*m.data[base])[rest:]) + } + sectors := (size + memSectorSize - 1) / memSectorSize + for sectors > int64(len(m.data)) { + m.data = append(m.data, new([memSectorSize]byte)) + } + for sectors < int64(len(m.data)) { + last := int64(len(m.data)) - 1 + m.data[last] = nil + m.data = m.data[:last] + } + m.size = size + return nil +} + +func (*memoryFile) Sync(flag SyncFlag) error { + return nil +} + +func (m *memoryFile) Size() (int64, error) { + return m.size, nil +} + +func (m *memoryFile) Lock(lock LockLevel) error { + if m.readOnly && lock >= LOCK_RESERVED { + return _IOERR_LOCK + } + if m.locked || m.mtx.TryLock() { + m.locked = true + return nil + } + return _BUSY +} + +func (m *memoryFile) Unlock(lock LockLevel) error { + if m.locked && lock == LOCK_NONE { + m.locked = false + m.mtx.Unlock() + } + return nil +} + +func (m *memoryFile) CheckReservedLock() (bool, error) { + if m.locked { + return true, nil + } + if m.mtx.TryLock() { + m.mtx.Unlock() + return true, nil + } + return false, nil +} + +func (*memoryFile) SectorSize() int { + return memSectorSize +} + +func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic { + return IOCAP_ATOMIC | + IOCAP_SEQUENTIAL | + IOCAP_SAFE_APPEND | + IOCAP_POWERSAFE_OVERWRITE +} + +func (m *memoryFile) SizeHint(size int64) error { + if size > m.size { + return m.Truncate(size) + } + return nil +} + +func clear(b []byte) { + for i := range b { + b[i] = 0 + } +} diff --git a/sqlite3vfs/reader.go b/sqlite3vfs/reader.go index fb9be59..09d1a7a 100644 --- a/sqlite3vfs/reader.go +++ b/sqlite3vfs/reader.go @@ -5,7 +5,7 @@ import ( "io/fs" ) -// A ReaderVFS is [VFS] for immutable databases. +// A ReaderVFS is a [VFS] for immutable databases. type ReaderVFS map[string]SizeReaderAt var _ VFS = ReaderVFS{} @@ -64,11 +64,11 @@ func (readerFile) Sync(flag SyncFlag) error { return nil } -func (readerFile) Lock(elock LockLevel) error { +func (readerFile) Lock(lock LockLevel) error { return nil } -func (readerFile) Unlock(elock LockLevel) error { +func (readerFile) Unlock(lock LockLevel) error { return nil } diff --git a/sqlite3vfs/vfs.go b/sqlite3vfs/vfs.go index b568957..85d9fc3 100644 --- a/sqlite3vfs/vfs.go +++ b/sqlite3vfs/vfs.go @@ -222,9 +222,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs if n == 0 && err != io.EOF { return _IOERR_READ } - for i := range buf[n:] { - buf[n+i] = 0 - } + clear(buf[n:]) return _IOERR_SHORT_READ } diff --git a/tests/db_test.go b/tests/db_test.go index 1e8c22d..a511c94 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -6,6 +6,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/sqlite3vfs" ) func TestDB_memory(t *testing.T) { @@ -16,6 +17,13 @@ func TestDB_file(t *testing.T) { testDB(t, filepath.Join(t.TempDir(), "test.db")) } +func TestDB_VFS(t *testing.T) { + sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{ + "test.db": &sqlite3vfs.MemoryDB{}, + }) + testDB(t, "file:test.db?vfs=memvfs&_pragma=journal_mode(memory)") +} + func testDB(t *testing.T, name string) { t.Parallel() diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index df10f51..7ca2cfe 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -11,6 +11,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/sqlite3vfs" ) func TestParallel(t *testing.T) { @@ -21,7 +22,33 @@ func TestParallel(t *testing.T) { iter = 5000 } - name := filepath.Join(t.TempDir(), "test.db") + name := "file:" + + filepath.Join(t.TempDir(), "test.db") + + "?_pragma=busy_timeout(10000)" + + "&_pragma=locking_mode(normal)" + + "&_pragma=journal_mode(truncate)" + + "&_pragma=synchronous(off)" + testParallel(t, name, iter) + testIntegrity(t, name) +} + +func TestMemory(t *testing.T) { + var iter int + if testing.Short() { + iter = 1000 + } else { + iter = 5000 + } + + sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{ + "test.db": &sqlite3vfs.MemoryDB{}, + }) + + name := "file:test.db?vfs=memvfs" + + "&_pragma=busy_timeout(10000)" + + "&_pragma=locking_mode(normal)" + + "&_pragma=journal_mode(memory)" + + "&_pragma=synchronous(off)" testParallel(t, name, iter) testIntegrity(t, name) } @@ -31,8 +58,14 @@ func TestMultiProcess(t *testing.T) { t.Skip("skipping in short mode") } - name := filepath.Join(t.TempDir(), "test.db") - t.Setenv("TestMultiProcess_dbname", name) + file := filepath.Join(t.TempDir(), "test.db") + t.Setenv("TestMultiProcess_dbfile", file) + + name := "file:" + file + + "?_pragma=busy_timeout(10000)" + + "&_pragma=locking_mode(normal)" + + "&_pragma=journal_mode(truncate)" + + "&_pragma=synchronous(off)" cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess") out, err := cmd.StdoutPipe() @@ -57,11 +90,17 @@ func TestMultiProcess(t *testing.T) { } func TestChildProcess(t *testing.T) { - name := os.Getenv("TestMultiProcess_dbname") - if name == "" || testing.Short() { + file := os.Getenv("TestMultiProcess_dbfile") + if file == "" || testing.Short() { t.SkipNow() } + name := "file:" + file + + "?_pragma=busy_timeout(10000)" + + "&_pragma=locking_mode(normal)" + + "&_pragma=journal_mode(truncate)" + + "&_pragma=synchronous(off)" + testParallel(t, name, 1000) } @@ -73,16 +112,6 @@ func testParallel(t *testing.T, name string, n int) { } defer db.Close() - err = db.Exec(` - PRAGMA busy_timeout=10000; - PRAGMA synchronous=off; - PRAGMA locking_mode=normal; - PRAGMA journal_mode=truncate; - `) - if err != nil { - return err - } - err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) if err != nil { return err