diff --git a/sqlite3vfs/memory.go b/sqlite3vfs/memory.go index 9721921..697cc1f 100644 --- a/sqlite3vfs/memory.go +++ b/sqlite3vfs/memory.go @@ -47,16 +47,23 @@ const memSectorSize = 65536 // // A MemoryDB is safe to access concurrently through multiple SQLite connections. type MemoryDB struct { + // +checklocks:dataMtx MaxSize int64 - mtx sync.RWMutex - size int64 + // +checklocks:dataMtx data []*[memSectorSize]byte + // +checklocks:dataMtx + size int64 - locker sync.Mutex - pending *memoryFile + // +checklocks:lockMtx + pending *memoryFile + // +checklocks:lockMtx reserved *memoryFile - shared int + // +checklocks:lockMtx + shared int + + lockMtx sync.Mutex + dataMtx sync.RWMutex } // NewMemoryDB creates a new MemoryDB using mem as its initial contents. @@ -86,6 +93,12 @@ type memoryFile struct { readOnly bool } +var ( + // Ensure these interfaces are implemented: + _ FileLockState = &memoryFile{} + _ FileSizeHint = &memoryFile{} +) + // Close implements the [File] and [io.Closer] interfaces. func (m *memoryFile) Close() error { return m.Unlock(LOCK_NONE) @@ -93,8 +106,8 @@ func (m *memoryFile) Close() error { // ReadAt implements the [File] and [io.ReaderAt] interfaces. func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { - m.mtx.RLock() - defer m.mtx.RUnlock() + m.dataMtx.RLock() + defer m.dataMtx.RUnlock() if off >= m.size { return 0, io.EOF @@ -116,10 +129,11 @@ func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { // WriteAt implements the [File] and [io.WriterAt] interfaces. func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { - m.mtx.Lock() - defer m.mtx.Unlock() + m.dataMtx.Lock() + defer m.dataMtx.Unlock() - if m.MaxSize > 0 && off+int64(len(b)) > m.MaxSize { + size := off + int64(len(b)) + if m.MaxSize > 0 && size > m.MaxSize { return 0, _FULL } @@ -133,16 +147,20 @@ func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { // Assume writes are page aligned. return 0, io.ErrShortWrite } + if size > m.size { + m.size = size + } return n, nil } // Truncate implements the [File] interface. func (m *memoryFile) Truncate(size int64) error { - m.mtx.Lock() - defer m.mtx.Unlock() + m.dataMtx.Lock() + defer m.dataMtx.Unlock() return m.truncate(size) } +// +checklocks:m.dataMtx func (m *memoryFile) truncate(size int64) error { if m.MaxSize > 0 && size > m.MaxSize { return _FULL @@ -171,8 +189,8 @@ func (*memoryFile) Sync(flag SyncFlag) error { // Size implements the [File] interface. func (m *memoryFile) Size() (int64, error) { - m.mtx.RLock() - defer m.mtx.RUnlock() + m.dataMtx.RLock() + defer m.dataMtx.RUnlock() return m.size, nil } @@ -186,8 +204,8 @@ func (m *memoryFile) Lock(lock LockLevel) error { return _IOERR_LOCK } - m.locker.Lock() - defer m.locker.Unlock() + m.lockMtx.Lock() + defer m.lockMtx.Unlock() deadline := time.Now().Add(time.Millisecond) switch lock { @@ -196,9 +214,9 @@ func (m *memoryFile) Lock(lock LockLevel) error { if time.Now().After(deadline) { return _BUSY } - m.locker.Unlock() + m.lockMtx.Unlock() runtime.Gosched() - m.locker.Lock() + m.lockMtx.Lock() } m.shared++ @@ -221,9 +239,9 @@ func (m *memoryFile) Lock(lock LockLevel) error { if time.Now().After(deadline) { return _BUSY } - m.locker.Unlock() + m.lockMtx.Unlock() runtime.Gosched() - m.locker.Lock() + m.lockMtx.Lock() } } @@ -237,8 +255,8 @@ func (m *memoryFile) Unlock(lock LockLevel) error { return nil } - m.locker.Lock() - defer m.locker.Unlock() + m.lockMtx.Lock() + defer m.lockMtx.Unlock() if m.pending == m { m.pending = nil @@ -258,8 +276,8 @@ func (m *memoryFile) CheckReservedLock() (bool, error) { if m.lock >= LOCK_RESERVED { return true, nil } - m.locker.Lock() - defer m.locker.Unlock() + m.lockMtx.Lock() + defer m.lockMtx.Unlock() return m.reserved != nil, nil } @@ -278,8 +296,8 @@ func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic { // SizeHint implements the [FileSizeHint] interface. func (m *memoryFile) SizeHint(size int64) error { - m.mtx.Lock() - defer m.mtx.Unlock() + m.dataMtx.Lock() + defer m.dataMtx.Unlock() if size > m.size { return m.truncate(size) } diff --git a/sqlite3vfs/registry.go b/sqlite3vfs/registry.go index 3fb946f..ee1b7c6 100644 --- a/sqlite3vfs/registry.go +++ b/sqlite3vfs/registry.go @@ -3,6 +3,7 @@ package sqlite3vfs import "sync" var ( + // +checklocks:vfsRegistryMtx vfsRegistry map[string]VFS vfsRegistryMtx sync.Mutex ) diff --git a/sqlite3vfs/tests/mptest/mptest_test.go b/sqlite3vfs/tests/mptest/mptest_test.go index 59b8298..e264d12 100644 --- a/sqlite3vfs/tests/mptest/mptest_test.go +++ b/sqlite3vfs/tests/mptest/mptest_test.go @@ -214,7 +214,9 @@ func newContext(t *testing.T) context.Context { type logger struct{} type testWriter struct { + // +checklocks:mtx *testing.T + // +checklocks:mtx buf []byte mtx sync.Mutex }