diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index c285b78..ce6115d 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -61,6 +61,8 @@ func Test_memdb(t *testing.T) { iter = 5000 } + memdb.Delete("test.db") + memdb.Create("test.db", nil) name := "file:/test.db?vfs=memdb" testParallel(t, name, iter) testIntegrity(t, name) @@ -145,6 +147,7 @@ func Benchmark_memdb(b *testing.B) { b.ResetTimer() memdb.Delete("test.db") + memdb.Create("test.db", nil) name := "file:/test.db?vfs=memdb" testParallel(b, name, b.N) } diff --git a/vfs/memdb/api.go b/vfs/memdb/api.go index c32cf1a..5a2b84c 100644 --- a/vfs/memdb/api.go +++ b/vfs/memdb/api.go @@ -33,8 +33,11 @@ func Create(name string, data []byte) { memoryMtx.Lock() defer memoryMtx.Unlock() - db := new(memDB) - db.size = int64(len(data)) + db := &memDB{ + refs: 1, + name: name, + size: int64(len(data)), + } // Convert data from WAL to rollback journal. if len(data) >= 20 && data[18] == 2 && data[19] == 2 { diff --git a/vfs/memdb/memdb.go b/vfs/memdb/memdb.go index 09ffa4e..8dc57ab 100644 --- a/vfs/memdb/memdb.go +++ b/vfs/memdb/memdb.go @@ -3,7 +3,6 @@ package memdb import ( "io" "runtime" - "strings" "sync" "time" @@ -34,22 +33,25 @@ func (memVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, err return nil, flags, sqlite3.CANTOPEN } - var db *memDB + // A shared database has a name that begins with "/". + shared := len(name) > 1 && name[0] == '/' - shared := strings.HasPrefix(name, "/") + var db *memDB if shared { + name = name[1:] memoryMtx.Lock() defer memoryMtx.Unlock() - db = memoryDBs[name[1:]] + db = memoryDBs[name] } if db == nil { if flags&vfs.OPEN_CREATE == 0 { return nil, flags, sqlite3.CANTOPEN } - db = new(memDB) + db = &memDB{name: name} } if shared { - memoryDBs[name[1:]] = db // +checklocksignore: lock is held + db.refs++ // +checklocksforce: memoryMtx is held + memoryDBs[name] = db } return &memFile{ @@ -71,6 +73,8 @@ func (memVFS) FullPathname(name string) (string, error) { } type memDB struct { + name string + // +checklocks:lockMtx pending *memFile // +checklocks:lockMtx @@ -85,10 +89,21 @@ type memDB struct { // +checklocks:lockMtx shared int + // +checklocks:memoryMtx + refs int + lockMtx sync.Mutex dataMtx sync.RWMutex } +func (m *memDB) release() { + memoryMtx.Lock() + defer memoryMtx.Unlock() + if m.refs--; m.refs == 0 && m == memoryDBs[m.name] { + delete(memoryDBs, m.name) + } +} + type memFile struct { *memDB lock vfs.LockLevel @@ -102,6 +117,7 @@ var ( ) func (m *memFile) Close() error { + m.release() return m.Unlock(vfs.LOCK_NONE) }