MemoryVFS.

This commit is contained in:
Nuno Cruces
2023-05-26 04:59:54 +01:00
parent 254d473546
commit 7ca9d79424
5 changed files with 221 additions and 21 deletions

165
sqlite3vfs/memory.go Normal file
View File

@@ -0,0 +1,165 @@
package sqlite3vfs
import (
"io"
"sync"
)
// A MemoryVFS is a [VFS] for memory databases.
type MemoryVFS map[string]*MemoryDB
var _ VFS = MemoryVFS{}
// Open implements the [VFS] interface.
func (vfs MemoryVFS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
if flags&OPEN_MAIN_DB == 0 {
return nil, flags, _CANTOPEN
}
if db, ok := vfs[name]; ok {
return &memoryFile{
MemoryDB: db,
readOnly: flags&OPEN_READONLY != 0,
}, flags, nil
}
return nil, flags, _CANTOPEN
}
// Delete implements the [VFS] interface.
func (vfs MemoryVFS) Delete(name string, dirSync bool) error {
return _IOERR_DELETE
}
// Access implements the [VFS] interface.
func (vfs MemoryVFS) Access(name string, flag AccessFlag) (bool, error) {
return false, nil
}
// FullPathname implements the [VFS] interface.
func (vfs MemoryVFS) FullPathname(name string) (string, error) {
return name, nil
}
const memSectorSize = 65536
type MemoryDB struct {
size int64
data []*[memSectorSize]byte
mtx sync.Mutex
}
type memoryFile struct {
*MemoryDB
locked bool
readOnly bool
}
func (m *memoryFile) Close() error {
return m.Unlock(LOCK_NONE)
}
func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) {
if off >= m.size {
return 0, io.EOF
}
base := off / memSectorSize
rest := off % memSectorSize
have := int64(memSectorSize)
if base == int64(len(m.data))-1 {
have = m.size % memSectorSize
}
return copy(b, (*m.data[base])[rest:have]), nil
}
func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) {
base := off / memSectorSize
rest := off % memSectorSize
if base >= int64(len(m.data)) {
m.data = append(m.data, new([memSectorSize]byte))
}
n = copy((*m.data[base])[rest:], b)
if size := off + int64(n); size > m.size {
m.size = size
}
return n, nil
}
func (m *memoryFile) Truncate(size int64) error {
if size < m.size {
base := size / memSectorSize
rest := size % memSectorSize
clear((*m.data[base])[rest:])
}
sectors := (size + memSectorSize - 1) / memSectorSize
for sectors > int64(len(m.data)) {
m.data = append(m.data, new([memSectorSize]byte))
}
for sectors < int64(len(m.data)) {
last := int64(len(m.data)) - 1
m.data[last] = nil
m.data = m.data[:last]
}
m.size = size
return nil
}
func (*memoryFile) Sync(flag SyncFlag) error {
return nil
}
func (m *memoryFile) Size() (int64, error) {
return m.size, nil
}
func (m *memoryFile) Lock(lock LockLevel) error {
if m.readOnly && lock >= LOCK_RESERVED {
return _IOERR_LOCK
}
if m.locked || m.mtx.TryLock() {
m.locked = true
return nil
}
return _BUSY
}
func (m *memoryFile) Unlock(lock LockLevel) error {
if m.locked && lock == LOCK_NONE {
m.locked = false
m.mtx.Unlock()
}
return nil
}
func (m *memoryFile) CheckReservedLock() (bool, error) {
if m.locked {
return true, nil
}
if m.mtx.TryLock() {
m.mtx.Unlock()
return true, nil
}
return false, nil
}
func (*memoryFile) SectorSize() int {
return memSectorSize
}
func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic {
return IOCAP_ATOMIC |
IOCAP_SEQUENTIAL |
IOCAP_SAFE_APPEND |
IOCAP_POWERSAFE_OVERWRITE
}
func (m *memoryFile) SizeHint(size int64) error {
if size > m.size {
return m.Truncate(size)
}
return nil
}
func clear(b []byte) {
for i := range b {
b[i] = 0
}
}

View File

@@ -5,7 +5,7 @@ import (
"io/fs"
)
// A ReaderVFS is [VFS] for immutable databases.
// A ReaderVFS is a [VFS] for immutable databases.
type ReaderVFS map[string]SizeReaderAt
var _ VFS = ReaderVFS{}
@@ -64,11 +64,11 @@ func (readerFile) Sync(flag SyncFlag) error {
return nil
}
func (readerFile) Lock(elock LockLevel) error {
func (readerFile) Lock(lock LockLevel) error {
return nil
}
func (readerFile) Unlock(elock LockLevel) error {
func (readerFile) Unlock(lock LockLevel) error {
return nil
}

View File

@@ -222,9 +222,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
if n == 0 && err != io.EOF {
return _IOERR_READ
}
for i := range buf[n:] {
buf[n+i] = 0
}
clear(buf[n:])
return _IOERR_SHORT_READ
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
)
func TestDB_memory(t *testing.T) {
@@ -16,6 +17,13 @@ func TestDB_file(t *testing.T) {
testDB(t, filepath.Join(t.TempDir(), "test.db"))
}
func TestDB_VFS(t *testing.T) {
sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{
"test.db": &sqlite3vfs.MemoryDB{},
})
testDB(t, "file:test.db?vfs=memvfs&_pragma=journal_mode(memory)")
}
func testDB(t *testing.T, name string) {
t.Parallel()

View File

@@ -11,6 +11,7 @@ import (
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/sqlite3vfs"
)
func TestParallel(t *testing.T) {
@@ -21,7 +22,33 @@ func TestParallel(t *testing.T) {
iter = 5000
}
name := filepath.Join(t.TempDir(), "test.db")
name := "file:" +
filepath.Join(t.TempDir(), "test.db") +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
testIntegrity(t, name)
}
func TestMemory(t *testing.T) {
var iter int
if testing.Short() {
iter = 1000
} else {
iter = 5000
}
sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{
"test.db": &sqlite3vfs.MemoryDB{},
})
name := "file:test.db?vfs=memvfs" +
"&_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(memory)" +
"&_pragma=synchronous(off)"
testParallel(t, name, iter)
testIntegrity(t, name)
}
@@ -31,8 +58,14 @@ func TestMultiProcess(t *testing.T) {
t.Skip("skipping in short mode")
}
name := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbname", name)
file := filepath.Join(t.TempDir(), "test.db")
t.Setenv("TestMultiProcess_dbfile", file)
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
out, err := cmd.StdoutPipe()
@@ -57,11 +90,17 @@ func TestMultiProcess(t *testing.T) {
}
func TestChildProcess(t *testing.T) {
name := os.Getenv("TestMultiProcess_dbname")
if name == "" || testing.Short() {
file := os.Getenv("TestMultiProcess_dbfile")
if file == "" || testing.Short() {
t.SkipNow()
}
name := "file:" + file +
"?_pragma=busy_timeout(10000)" +
"&_pragma=locking_mode(normal)" +
"&_pragma=journal_mode(truncate)" +
"&_pragma=synchronous(off)"
testParallel(t, name, 1000)
}
@@ -73,16 +112,6 @@ func testParallel(t *testing.T, name string, n int) {
}
defer db.Close()
err = db.Exec(`
PRAGMA busy_timeout=10000;
PRAGMA synchronous=off;
PRAGMA locking_mode=normal;
PRAGMA journal_mode=truncate;
`)
if err != nil {
return err
}
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
return err