mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
MemoryVFS.
This commit is contained in:
165
sqlite3vfs/memory.go
Normal file
165
sqlite3vfs/memory.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package sqlite3vfs
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A MemoryVFS is a [VFS] for memory databases.
|
||||
type MemoryVFS map[string]*MemoryDB
|
||||
|
||||
var _ VFS = MemoryVFS{}
|
||||
|
||||
// Open implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
|
||||
if flags&OPEN_MAIN_DB == 0 {
|
||||
return nil, flags, _CANTOPEN
|
||||
}
|
||||
if db, ok := vfs[name]; ok {
|
||||
return &memoryFile{
|
||||
MemoryDB: db,
|
||||
readOnly: flags&OPEN_READONLY != 0,
|
||||
}, flags, nil
|
||||
}
|
||||
return nil, flags, _CANTOPEN
|
||||
}
|
||||
|
||||
// Delete implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) Delete(name string, dirSync bool) error {
|
||||
return _IOERR_DELETE
|
||||
}
|
||||
|
||||
// Access implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) Access(name string, flag AccessFlag) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// FullPathname implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) FullPathname(name string) (string, error) {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
const memSectorSize = 65536
|
||||
|
||||
type MemoryDB struct {
|
||||
size int64
|
||||
data []*[memSectorSize]byte
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
type memoryFile struct {
|
||||
*MemoryDB
|
||||
locked bool
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
func (m *memoryFile) Close() error {
|
||||
return m.Unlock(LOCK_NONE)
|
||||
}
|
||||
|
||||
func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) {
|
||||
if off >= m.size {
|
||||
return 0, io.EOF
|
||||
}
|
||||
base := off / memSectorSize
|
||||
rest := off % memSectorSize
|
||||
have := int64(memSectorSize)
|
||||
if base == int64(len(m.data))-1 {
|
||||
have = m.size % memSectorSize
|
||||
}
|
||||
return copy(b, (*m.data[base])[rest:have]), nil
|
||||
}
|
||||
|
||||
func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
base := off / memSectorSize
|
||||
rest := off % memSectorSize
|
||||
if base >= int64(len(m.data)) {
|
||||
m.data = append(m.data, new([memSectorSize]byte))
|
||||
}
|
||||
n = copy((*m.data[base])[rest:], b)
|
||||
if size := off + int64(n); size > m.size {
|
||||
m.size = size
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *memoryFile) Truncate(size int64) error {
|
||||
if size < m.size {
|
||||
base := size / memSectorSize
|
||||
rest := size % memSectorSize
|
||||
clear((*m.data[base])[rest:])
|
||||
}
|
||||
sectors := (size + memSectorSize - 1) / memSectorSize
|
||||
for sectors > int64(len(m.data)) {
|
||||
m.data = append(m.data, new([memSectorSize]byte))
|
||||
}
|
||||
for sectors < int64(len(m.data)) {
|
||||
last := int64(len(m.data)) - 1
|
||||
m.data[last] = nil
|
||||
m.data = m.data[:last]
|
||||
}
|
||||
m.size = size
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*memoryFile) Sync(flag SyncFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memoryFile) Size() (int64, error) {
|
||||
return m.size, nil
|
||||
}
|
||||
|
||||
func (m *memoryFile) Lock(lock LockLevel) error {
|
||||
if m.readOnly && lock >= LOCK_RESERVED {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
if m.locked || m.mtx.TryLock() {
|
||||
m.locked = true
|
||||
return nil
|
||||
}
|
||||
return _BUSY
|
||||
}
|
||||
|
||||
func (m *memoryFile) Unlock(lock LockLevel) error {
|
||||
if m.locked && lock == LOCK_NONE {
|
||||
m.locked = false
|
||||
m.mtx.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *memoryFile) CheckReservedLock() (bool, error) {
|
||||
if m.locked {
|
||||
return true, nil
|
||||
}
|
||||
if m.mtx.TryLock() {
|
||||
m.mtx.Unlock()
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (*memoryFile) SectorSize() int {
|
||||
return memSectorSize
|
||||
}
|
||||
|
||||
func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic {
|
||||
return IOCAP_ATOMIC |
|
||||
IOCAP_SEQUENTIAL |
|
||||
IOCAP_SAFE_APPEND |
|
||||
IOCAP_POWERSAFE_OVERWRITE
|
||||
}
|
||||
|
||||
func (m *memoryFile) SizeHint(size int64) error {
|
||||
if size > m.size {
|
||||
return m.Truncate(size)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func clear(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
// A ReaderVFS is [VFS] for immutable databases.
|
||||
// A ReaderVFS is a [VFS] for immutable databases.
|
||||
type ReaderVFS map[string]SizeReaderAt
|
||||
|
||||
var _ VFS = ReaderVFS{}
|
||||
@@ -64,11 +64,11 @@ func (readerFile) Sync(flag SyncFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (readerFile) Lock(elock LockLevel) error {
|
||||
func (readerFile) Lock(lock LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (readerFile) Unlock(elock LockLevel) error {
|
||||
func (readerFile) Unlock(lock LockLevel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -222,9 +222,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf, iAmt uint32, iOfs
|
||||
if n == 0 && err != io.EOF {
|
||||
return _IOERR_READ
|
||||
}
|
||||
for i := range buf[n:] {
|
||||
buf[n+i] = 0
|
||||
}
|
||||
clear(buf[n:])
|
||||
return _IOERR_SHORT_READ
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
)
|
||||
|
||||
func TestDB_memory(t *testing.T) {
|
||||
@@ -16,6 +17,13 @@ func TestDB_file(t *testing.T) {
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func TestDB_VFS(t *testing.T) {
|
||||
sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{
|
||||
"test.db": &sqlite3vfs.MemoryDB{},
|
||||
})
|
||||
testDB(t, "file:test.db?vfs=memvfs&_pragma=journal_mode(memory)")
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
)
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
@@ -21,7 +22,33 @@ func TestParallel(t *testing.T) {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
name := "file:" +
|
||||
filepath.Join(t.TempDir(), "test.db") +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
|
||||
func TestMemory(t *testing.T) {
|
||||
var iter int
|
||||
if testing.Short() {
|
||||
iter = 1000
|
||||
} else {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{
|
||||
"test.db": &sqlite3vfs.MemoryDB{},
|
||||
})
|
||||
|
||||
name := "file:test.db?vfs=memvfs" +
|
||||
"&_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(memory)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
testParallel(t, name, iter)
|
||||
testIntegrity(t, name)
|
||||
}
|
||||
@@ -31,8 +58,14 @@ func TestMultiProcess(t *testing.T) {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
name := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbname", name)
|
||||
file := filepath.Join(t.TempDir(), "test.db")
|
||||
t.Setenv("TestMultiProcess_dbfile", file)
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
cmd := exec.Command("go", "test", "-v", "-run", "TestChildProcess")
|
||||
out, err := cmd.StdoutPipe()
|
||||
@@ -57,11 +90,17 @@ func TestMultiProcess(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChildProcess(t *testing.T) {
|
||||
name := os.Getenv("TestMultiProcess_dbname")
|
||||
if name == "" || testing.Short() {
|
||||
file := os.Getenv("TestMultiProcess_dbfile")
|
||||
if file == "" || testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
name := "file:" + file +
|
||||
"?_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(truncate)" +
|
||||
"&_pragma=synchronous(off)"
|
||||
|
||||
testParallel(t, name, 1000)
|
||||
}
|
||||
|
||||
@@ -73,16 +112,6 @@ func testParallel(t *testing.T, name string, n int) {
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`
|
||||
PRAGMA busy_timeout=10000;
|
||||
PRAGMA synchronous=off;
|
||||
PRAGMA locking_mode=normal;
|
||||
PRAGMA journal_mode=truncate;
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user