diff --git a/README.md b/README.md index 9608fa7..6f170c3 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ Performance is tested by running - [ ] session extension - [ ] custom SQL functions - [ ] custom VFSes - - [ ] in-memory VFS + - [x] custom VFS API + - [x] in-memory VFS - [x] read-only VFS, wrapping an [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt) - [ ] cloud-based VFS, based on [Cloud Backed SQLite](https://sqlite.org/cloudsqlite/doc/trunk/www/index.wiki) - - [x] custom VFS API ### Alternatives diff --git a/sqlite3vfs/lock.go b/sqlite3vfs/lock.go index 496622c..e05580a 100644 --- a/sqlite3vfs/lock.go +++ b/sqlite3vfs/lock.go @@ -14,73 +14,73 @@ const ( _SHARED_SIZE = 510 ) -func (file *vfsFile) Lock(eLock LockLevel) error { +func (f *vfsFile) Lock(lock LockLevel) error { // Argument check. SQLite never explicitly requests a pending lock. - if eLock != LOCK_SHARED && eLock != LOCK_RESERVED && eLock != LOCK_EXCLUSIVE { + if lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE { panic(util.AssertErr()) } switch { - case file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE: + case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE: // Connection state check. panic(util.AssertErr()) - case file.lock == LOCK_NONE && eLock > LOCK_SHARED: + case f.lock == LOCK_NONE && lock > LOCK_SHARED: // We never move from unlocked to anything higher than a shared lock. panic(util.AssertErr()) - case file.lock != LOCK_SHARED && eLock == LOCK_RESERVED: + case f.lock != LOCK_SHARED && lock == LOCK_RESERVED: // A shared lock is always held when a reserved lock is requested. panic(util.AssertErr()) } // If we already have an equal or more restrictive lock, do nothing. - if file.lock >= eLock { + if f.lock >= lock { return nil } // Do not allow any kind of write-lock on a read-only database. - if file.readOnly && eLock >= LOCK_RESERVED { + if f.readOnly && lock >= LOCK_RESERVED { return _IOERR_LOCK } - switch eLock { + switch lock { case LOCK_SHARED: // Must be unlocked to get SHARED. - if file.lock != LOCK_NONE { + if f.lock != LOCK_NONE { panic(util.AssertErr()) } - if rc := osGetSharedLock(file.File, file.lockTimeout); rc != _OK { + if rc := osGetSharedLock(f.File, f.lockTimeout); rc != _OK { return rc } - file.lock = LOCK_SHARED + f.lock = LOCK_SHARED return nil case LOCK_RESERVED: // Must be SHARED to get RESERVED. - if file.lock != LOCK_SHARED { + if f.lock != LOCK_SHARED { panic(util.AssertErr()) } - if rc := osGetReservedLock(file.File, file.lockTimeout); rc != _OK { + if rc := osGetReservedLock(f.File, f.lockTimeout); rc != _OK { return rc } - file.lock = LOCK_RESERVED + f.lock = LOCK_RESERVED return nil case LOCK_EXCLUSIVE: // Must be SHARED, RESERVED or PENDING to get EXCLUSIVE. - if file.lock <= LOCK_NONE || file.lock >= LOCK_EXCLUSIVE { + if f.lock <= LOCK_NONE || f.lock >= LOCK_EXCLUSIVE { panic(util.AssertErr()) } // A PENDING lock is needed before acquiring an EXCLUSIVE lock. - if file.lock < LOCK_PENDING { - if rc := osGetPendingLock(file.File); rc != _OK { + if f.lock < LOCK_PENDING { + if rc := osGetPendingLock(f.File); rc != _OK { return rc } - file.lock = LOCK_PENDING + f.lock = LOCK_PENDING } - if rc := osGetExclusiveLock(file.File, file.lockTimeout); rc != _OK { + if rc := osGetExclusiveLock(f.File, f.lockTimeout); rc != _OK { return rc } - file.lock = LOCK_EXCLUSIVE + f.lock = LOCK_EXCLUSIVE return nil default: @@ -88,33 +88,33 @@ func (file *vfsFile) Lock(eLock LockLevel) error { } } -func (file *vfsFile) Unlock(eLock LockLevel) error { +func (f *vfsFile) Unlock(lock LockLevel) error { // Argument check. - if eLock != LOCK_NONE && eLock != LOCK_SHARED { + if lock != LOCK_NONE && lock != LOCK_SHARED { panic(util.AssertErr()) } // Connection state check. - if file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE { + if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE { panic(util.AssertErr()) } // If we don't have a more restrictive lock, do nothing. - if file.lock <= eLock { + if f.lock <= lock { return nil } - switch eLock { + switch lock { case LOCK_SHARED: - if rc := osDowngradeLock(file.File, file.lock); rc != _OK { + if rc := osDowngradeLock(f.File, f.lock); rc != _OK { return rc } - file.lock = LOCK_SHARED + f.lock = LOCK_SHARED return nil case LOCK_NONE: - rc := osReleaseLock(file.File, file.lock) - file.lock = LOCK_NONE + rc := osReleaseLock(f.File, f.lock) + f.lock = LOCK_NONE return rc default: @@ -122,16 +122,16 @@ func (file *vfsFile) Unlock(eLock LockLevel) error { } } -func (file *vfsFile) CheckReservedLock() (bool, error) { +func (f *vfsFile) CheckReservedLock() (bool, error) { // Connection state check. - if file.lock < LOCK_NONE || file.lock > LOCK_EXCLUSIVE { + if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE { panic(util.AssertErr()) } - if file.lock >= LOCK_RESERVED { + if f.lock >= LOCK_RESERVED { return true, nil } - return osCheckReservedLock(file.File) + return osCheckReservedLock(f.File) } func osGetReservedLock(file *os.File, timeout time.Duration) _ErrorCode { diff --git a/sqlite3vfs/memory.go b/sqlite3vfs/memory.go index f508e4e..b6e114c 100644 --- a/sqlite3vfs/memory.go +++ b/sqlite3vfs/memory.go @@ -2,7 +2,9 @@ package sqlite3vfs import ( "io" + "runtime" "sync" + "time" ) // A MemoryVFS is a [VFS] for memory databases. @@ -42,14 +44,19 @@ func (vfs MemoryVFS) FullPathname(name string) (string, error) { const memSectorSize = 65536 type MemoryDB struct { + mtx sync.RWMutex size int64 data []*[memSectorSize]byte - mtx sync.Mutex + + locker sync.Mutex + pending *memoryFile + reserved *memoryFile + shared int } type memoryFile struct { *MemoryDB - locked bool + lock LockLevel readOnly bool } @@ -58,6 +65,9 @@ func (m *memoryFile) Close() error { } func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { + m.mtx.RLock() + defer m.mtx.RUnlock() + if off >= m.size { return 0, io.EOF } @@ -71,6 +81,9 @@ func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { } func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { + m.mtx.Lock() + defer m.mtx.Unlock() + base := off / memSectorSize rest := off % memSectorSize if base >= int64(len(m.data)) { @@ -84,6 +97,12 @@ func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { } func (m *memoryFile) Truncate(size int64) error { + m.mtx.Lock() + defer m.mtx.Unlock() + return m.truncate(size) +} + +func (m *memoryFile) truncate(size int64) error { if size < m.size { base := size / memSectorSize rest := size % memSectorSize @@ -107,37 +126,93 @@ func (*memoryFile) Sync(flag SyncFlag) error { } func (m *memoryFile) Size() (int64, error) { + m.mtx.RLock() + defer m.mtx.RUnlock() return m.size, nil } func (m *memoryFile) Lock(lock LockLevel) error { + if m.lock >= lock { + return nil + } + if m.readOnly && lock >= LOCK_RESERVED { return _IOERR_LOCK } - if m.locked || m.mtx.TryLock() { - m.locked = true - return nil + + m.locker.Lock() + defer m.locker.Unlock() + deadline := time.Now().Add(time.Millisecond) + + switch lock { + case LOCK_SHARED: + for m.pending != nil { + if time.Now().After(deadline) { + return _BUSY + } + m.locker.Unlock() + runtime.Gosched() + m.locker.Lock() + } + m.shared++ + + case LOCK_RESERVED: + if m.reserved != nil { + return _BUSY + } + m.reserved = m + + case LOCK_EXCLUSIVE: + if m.lock < LOCK_PENDING { + if m.pending != nil { + return _BUSY + } + m.lock = LOCK_PENDING + m.pending = m + } + + for m.shared > 1 { + if time.Now().After(deadline) { + return _BUSY + } + m.locker.Unlock() + runtime.Gosched() + m.locker.Lock() + } } - return _BUSY + + m.lock = lock + return nil } func (m *memoryFile) Unlock(lock LockLevel) error { - if m.locked && lock == LOCK_NONE { - m.locked = false - m.mtx.Unlock() + if m.lock <= lock { + return nil } + + m.locker.Lock() + defer m.locker.Unlock() + + if m.pending == m { + m.pending = nil + } + if m.reserved == m { + m.reserved = nil + } + if lock < LOCK_SHARED { + m.shared-- + } + m.lock = lock return nil } func (m *memoryFile) CheckReservedLock() (bool, error) { - if m.locked { + if m.lock >= LOCK_RESERVED { return true, nil } - if m.mtx.TryLock() { - m.mtx.Unlock() - return true, nil - } - return false, nil + m.locker.Lock() + defer m.locker.Unlock() + return m.reserved != nil, nil } func (*memoryFile) SectorSize() int { @@ -152,12 +227,18 @@ func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic { } func (m *memoryFile) SizeHint(size int64) error { + m.mtx.Lock() + defer m.mtx.Unlock() if size > m.size { - return m.Truncate(size) + return m.truncate(size) } return nil } +func (m *memoryFile) LockState() LockLevel { + return m.lock +} + func clear(b []byte) { for i := range b { b[i] = 0