mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Refactor memdb API.
This commit is contained in:
@@ -83,7 +83,7 @@ type module struct {
|
||||
}
|
||||
|
||||
func newModule(mod api.Module) (m *module, err error) {
|
||||
m = &module{}
|
||||
m = new(module)
|
||||
m.mod = mod
|
||||
m.ctx, m.vfs = sqlite3vfs.NewContext(context.Background())
|
||||
|
||||
|
||||
58
sqlite3memdb/api.go
Normal file
58
sqlite3memdb/api.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// Package sqlite3memdb implements the "memdb" SQLite VFS.
|
||||
//
|
||||
// The "memdb" [sqlite3vfs.VFS] allows the same in-memory database to be shared
|
||||
// among multiple database connections in the same process,
|
||||
// as long as the database name begins with "/".
|
||||
//
|
||||
// Importing package sqlite3memdb registers the VFS.
|
||||
//
|
||||
// import _ "github.com/ncruces/go-sqlite3/sqlite3memdb"
|
||||
package sqlite3memdb
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sqlite3vfs.Register("memdb", vfs{})
|
||||
}
|
||||
|
||||
var (
|
||||
memoryMtx sync.Mutex
|
||||
memoryDBs = map[string]*dbase{}
|
||||
)
|
||||
|
||||
// Create creates a shared memory database,
|
||||
// using data as its initial contents.
|
||||
// The new database takes ownership of data,
|
||||
// and the caller should not use data after this call.
|
||||
func Create(name string, data []byte) {
|
||||
memoryMtx.Lock()
|
||||
defer memoryMtx.Unlock()
|
||||
|
||||
db := new(dbase)
|
||||
db.size = int64(len(data))
|
||||
|
||||
sectors := divRoundUp(db.size, sectorSize)
|
||||
db.data = make([]*[sectorSize]byte, sectors)
|
||||
for i := range db.data {
|
||||
sector := data[i*sectorSize:]
|
||||
if len(sector) >= sectorSize {
|
||||
db.data[i] = (*[sectorSize]byte)(sector)
|
||||
} else {
|
||||
db.data[i] = new([sectorSize]byte)
|
||||
copy((*db.data[i])[:], sector)
|
||||
}
|
||||
}
|
||||
|
||||
memoryDBs[name] = db
|
||||
}
|
||||
|
||||
// Delete deletes a shared memory database.
|
||||
func Delete(name string) {
|
||||
memoryMtx.Lock()
|
||||
defer memoryMtx.Unlock()
|
||||
delete(memoryDBs, name)
|
||||
}
|
||||
51
sqlite3memdb/example_test.go
Normal file
51
sqlite3memdb/example_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package sqlite3memdb_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "embed"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3memdb"
|
||||
)
|
||||
|
||||
//go:embed testdata/test.db
|
||||
var testDB []byte
|
||||
|
||||
func Example() {
|
||||
sqlite3memdb.Create("test.db", testDB)
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:/test.db?vfs=memdb")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s %s\n", id, name)
|
||||
}
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 2 whatever
|
||||
// 3 rust
|
||||
}
|
||||
292
sqlite3memdb/memdb.go
Normal file
292
sqlite3memdb/memdb.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package sqlite3memdb
|
||||
|
||||
import (
|
||||
"io"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
)
|
||||
|
||||
type vfs struct{}
|
||||
|
||||
func (vfs) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, sqlite3vfs.OpenFlag, error) {
|
||||
if flags&sqlite3vfs.OPEN_MAIN_DB == 0 {
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
|
||||
var db *dbase
|
||||
|
||||
shared := strings.HasPrefix(name, "/")
|
||||
if shared {
|
||||
memoryMtx.Lock()
|
||||
defer memoryMtx.Unlock()
|
||||
db = memoryDBs[name[1:]]
|
||||
}
|
||||
if db == nil {
|
||||
if flags&sqlite3vfs.OPEN_CREATE == 0 {
|
||||
return nil, flags, sqlite3.CANTOPEN
|
||||
}
|
||||
db = new(dbase)
|
||||
}
|
||||
if shared {
|
||||
memoryDBs[name[1:]] = db
|
||||
}
|
||||
|
||||
return &file{
|
||||
dbase: db,
|
||||
readOnly: flags&sqlite3vfs.OPEN_READONLY != 0,
|
||||
}, flags | sqlite3vfs.OPEN_MEMORY, nil
|
||||
}
|
||||
|
||||
func (vfs) Delete(name string, dirSync bool) error {
|
||||
return sqlite3.IOERR_DELETE
|
||||
}
|
||||
|
||||
func (vfs) Access(name string, flag sqlite3vfs.AccessFlag) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (vfs) FullPathname(name string) (string, error) {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
const sectorSize = 65536
|
||||
|
||||
type dbase struct {
|
||||
// +checklocks:lockMtx
|
||||
pending *file
|
||||
// +checklocks:lockMtx
|
||||
reserved *file
|
||||
|
||||
// +checklocks:dataMtx
|
||||
data []*[sectorSize]byte
|
||||
|
||||
// +checklocks:dataMtx
|
||||
size int64
|
||||
|
||||
// +checklocks:lockMtx
|
||||
shared int
|
||||
|
||||
lockMtx sync.Mutex
|
||||
dataMtx sync.RWMutex
|
||||
}
|
||||
|
||||
type file struct {
|
||||
*dbase
|
||||
lock sqlite3vfs.LockLevel
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ sqlite3vfs.FileLockState = &file{}
|
||||
_ sqlite3vfs.FileSizeHint = &file{}
|
||||
)
|
||||
|
||||
func (m *file) Close() error {
|
||||
return m.Unlock(sqlite3vfs.LOCK_NONE)
|
||||
}
|
||||
|
||||
func (m *file) ReadAt(b []byte, off int64) (n int, err error) {
|
||||
m.dataMtx.RLock()
|
||||
defer m.dataMtx.RUnlock()
|
||||
|
||||
if off >= m.size {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
base := off / sectorSize
|
||||
rest := off % sectorSize
|
||||
have := int64(sectorSize)
|
||||
if base == int64(len(m.data))-1 {
|
||||
have = modRoundUp(m.size, sectorSize)
|
||||
}
|
||||
n = copy(b, (*m.data[base])[rest:have])
|
||||
if n < len(b) {
|
||||
// Assume reads are page aligned.
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *file) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
|
||||
base := off / sectorSize
|
||||
rest := off % sectorSize
|
||||
for base >= int64(len(m.data)) {
|
||||
m.data = append(m.data, new([sectorSize]byte))
|
||||
}
|
||||
n = copy((*m.data[base])[rest:], b)
|
||||
if n < len(b) {
|
||||
// Assume writes are page aligned.
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
if size := off + int64(len(b)); size > m.size {
|
||||
m.size = size
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *file) Truncate(size int64) error {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
return m.truncate(size)
|
||||
}
|
||||
|
||||
// +checklocks:m.dataMtx
|
||||
func (m *file) truncate(size int64) error {
|
||||
if size < m.size {
|
||||
base := size / sectorSize
|
||||
rest := size % sectorSize
|
||||
if rest != 0 {
|
||||
clear((*m.data[base])[rest:])
|
||||
}
|
||||
}
|
||||
sectors := divRoundUp(size, sectorSize)
|
||||
for sectors > int64(len(m.data)) {
|
||||
m.data = append(m.data, new([sectorSize]byte))
|
||||
}
|
||||
clear(m.data[sectors:])
|
||||
m.data = m.data[:sectors]
|
||||
m.size = size
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*file) Sync(flag sqlite3vfs.SyncFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *file) Size() (int64, error) {
|
||||
m.dataMtx.RLock()
|
||||
defer m.dataMtx.RUnlock()
|
||||
return m.size, nil
|
||||
}
|
||||
|
||||
func (m *file) Lock(lock sqlite3vfs.LockLevel) error {
|
||||
if m.lock >= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.readOnly && lock >= sqlite3vfs.LOCK_RESERVED {
|
||||
return sqlite3.IOERR_LOCK
|
||||
}
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
deadline := time.Now().Add(time.Millisecond)
|
||||
|
||||
switch lock {
|
||||
case sqlite3vfs.LOCK_SHARED:
|
||||
for m.pending != nil {
|
||||
if time.Now().After(deadline) {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
}
|
||||
m.shared++
|
||||
|
||||
case sqlite3vfs.LOCK_RESERVED:
|
||||
if m.reserved != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.reserved = m
|
||||
|
||||
case sqlite3vfs.LOCK_EXCLUSIVE:
|
||||
if m.lock < sqlite3vfs.LOCK_PENDING {
|
||||
if m.pending != nil {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lock = sqlite3vfs.LOCK_PENDING
|
||||
m.pending = m
|
||||
}
|
||||
|
||||
for m.shared > 1 {
|
||||
if time.Now().After(deadline) {
|
||||
return sqlite3.BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
m.lock = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *file) Unlock(lock sqlite3vfs.LockLevel) error {
|
||||
if m.lock <= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
|
||||
if m.pending == m {
|
||||
m.pending = nil
|
||||
}
|
||||
if m.reserved == m {
|
||||
m.reserved = nil
|
||||
}
|
||||
if lock < sqlite3vfs.LOCK_SHARED {
|
||||
m.shared--
|
||||
}
|
||||
m.lock = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *file) CheckReservedLock() (bool, error) {
|
||||
if m.lock >= sqlite3vfs.LOCK_RESERVED {
|
||||
return true, nil
|
||||
}
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
return m.reserved != nil, nil
|
||||
}
|
||||
|
||||
func (*file) SectorSize() int {
|
||||
return sectorSize
|
||||
}
|
||||
|
||||
func (*file) DeviceCharacteristics() sqlite3vfs.DeviceCharacteristic {
|
||||
return sqlite3vfs.IOCAP_ATOMIC |
|
||||
sqlite3vfs.IOCAP_SEQUENTIAL |
|
||||
sqlite3vfs.IOCAP_SAFE_APPEND |
|
||||
sqlite3vfs.IOCAP_POWERSAFE_OVERWRITE
|
||||
}
|
||||
|
||||
func (m *file) SizeHint(size int64) error {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
if size > m.size {
|
||||
return m.truncate(size)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *file) LockState() sqlite3vfs.LockLevel {
|
||||
return m.lock
|
||||
}
|
||||
|
||||
func divRoundUp(a, b int64) int64 {
|
||||
return (a + b - 1) / b
|
||||
}
|
||||
|
||||
func modRoundUp(a, b int64) int64 {
|
||||
return b - (b-a%b)%b
|
||||
}
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
@@ -22,7 +22,6 @@ const (
|
||||
_READONLY _ErrorCode = util.READONLY
|
||||
_IOERR _ErrorCode = util.IOERR
|
||||
_NOTFOUND _ErrorCode = util.NOTFOUND
|
||||
_FULL _ErrorCode = util.FULL
|
||||
_CANTOPEN _ErrorCode = util.CANTOPEN
|
||||
_IOERR_READ _ErrorCode = util.IOERR_READ
|
||||
_IOERR_SHORT_READ _ErrorCode = util.IOERR_SHORT_READ
|
||||
|
||||
@@ -17,43 +17,6 @@ import (
|
||||
//go:embed testdata/test.db
|
||||
var testDB []byte
|
||||
|
||||
func ExampleMemoryVFS_embed() {
|
||||
sqlite3vfs.Register("memory", sqlite3vfs.MemoryVFS{
|
||||
"test.db": sqlite3vfs.NewMemoryDB(testDB),
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3", "file:test.db?vfs=memory")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(`SELECT id, name FROM users`)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, name string
|
||||
err = rows.Scan(&id, &name)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("%s %s\n", id, name)
|
||||
}
|
||||
// Output:
|
||||
// 0 go
|
||||
// 1 zig
|
||||
// 2 whatever
|
||||
// 3 rust
|
||||
}
|
||||
|
||||
func ExampleReaderVFS_http() {
|
||||
sqlite3vfs.Register("httpvfs", sqlite3vfs.ReaderVFS{
|
||||
"demo.db": httpreadat.New("https://www.sanford.io/demo.db"),
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
package sqlite3vfs
|
||||
|
||||
import (
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A MemoryVFS is a [VFS] for memory databases.
|
||||
type MemoryVFS map[string]*MemoryDB
|
||||
|
||||
var _ VFS = MemoryVFS{}
|
||||
|
||||
// Open implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
|
||||
if flags&OPEN_MAIN_DB == 0 {
|
||||
return nil, flags, _CANTOPEN
|
||||
}
|
||||
if db, ok := vfs[name]; ok {
|
||||
return &memoryFile{
|
||||
MemoryDB: db,
|
||||
readOnly: flags&OPEN_READONLY != 0,
|
||||
}, flags | OPEN_MEMORY, nil
|
||||
}
|
||||
return nil, flags, _CANTOPEN
|
||||
}
|
||||
|
||||
// Delete implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) Delete(name string, dirSync bool) error {
|
||||
return _IOERR_DELETE
|
||||
}
|
||||
|
||||
// Access implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) Access(name string, flag AccessFlag) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// FullPathname implements the [VFS] interface.
|
||||
func (vfs MemoryVFS) FullPathname(name string) (string, error) {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
const memSectorSize = 65536
|
||||
|
||||
// A MemoryDB is a [MemoryVFS] database.
|
||||
//
|
||||
// A MemoryDB is safe to access concurrently through multiple SQLite connections.
|
||||
type MemoryDB struct {
|
||||
// +checklocks:dataMtx
|
||||
MaxSize int64
|
||||
|
||||
// +checklocks:dataMtx
|
||||
data []*[memSectorSize]byte
|
||||
// +checklocks:dataMtx
|
||||
size int64
|
||||
|
||||
// +checklocks:lockMtx
|
||||
pending *memoryFile
|
||||
// +checklocks:lockMtx
|
||||
reserved *memoryFile
|
||||
// +checklocks:lockMtx
|
||||
shared int
|
||||
|
||||
lockMtx sync.Mutex
|
||||
dataMtx sync.RWMutex
|
||||
}
|
||||
|
||||
// NewMemoryDB creates a new MemoryDB using mem as its initial contents.
|
||||
// The new MemoryDB takes ownership of mem, and the caller should not use mem after this call.
|
||||
func NewMemoryDB(mem []byte) *MemoryDB {
|
||||
m := new(MemoryDB)
|
||||
m.size = int64(len(mem))
|
||||
|
||||
sectors := divRoundUp(m.size, memSectorSize)
|
||||
m.data = make([]*[memSectorSize]byte, sectors)
|
||||
for i := range m.data {
|
||||
sector := mem[i*memSectorSize:]
|
||||
if len(sector) >= memSectorSize {
|
||||
m.data[i] = (*[memSectorSize]byte)(sector)
|
||||
} else {
|
||||
m.data[i] = new([memSectorSize]byte)
|
||||
copy((*m.data[i])[:], sector)
|
||||
}
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
type memoryFile struct {
|
||||
*MemoryDB
|
||||
lock LockLevel
|
||||
readOnly bool
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ FileLockState = &memoryFile{}
|
||||
_ FileSizeHint = &memoryFile{}
|
||||
)
|
||||
|
||||
// Close implements the [File] and [io.Closer] interfaces.
|
||||
func (m *memoryFile) Close() error {
|
||||
return m.Unlock(LOCK_NONE)
|
||||
}
|
||||
|
||||
// ReadAt implements the [File] and [io.ReaderAt] interfaces.
|
||||
func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) {
|
||||
m.dataMtx.RLock()
|
||||
defer m.dataMtx.RUnlock()
|
||||
|
||||
if off >= m.size {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
base := off / memSectorSize
|
||||
rest := off % memSectorSize
|
||||
have := int64(memSectorSize)
|
||||
if base == int64(len(m.data))-1 {
|
||||
have = modRoundUp(m.size, memSectorSize)
|
||||
}
|
||||
n = copy(b, (*m.data[base])[rest:have])
|
||||
if n < len(b) {
|
||||
// Assume reads are page aligned.
|
||||
return 0, io.ErrNoProgress
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// WriteAt implements the [File] and [io.WriterAt] interfaces.
|
||||
func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
|
||||
size := off + int64(len(b))
|
||||
if m.MaxSize > 0 && size > m.MaxSize {
|
||||
return 0, _FULL
|
||||
}
|
||||
|
||||
base := off / memSectorSize
|
||||
rest := off % memSectorSize
|
||||
for base >= int64(len(m.data)) {
|
||||
m.data = append(m.data, new([memSectorSize]byte))
|
||||
}
|
||||
n = copy((*m.data[base])[rest:], b)
|
||||
if n < len(b) {
|
||||
// Assume writes are page aligned.
|
||||
return 0, io.ErrShortWrite
|
||||
}
|
||||
if size > m.size {
|
||||
m.size = size
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Truncate implements the [File] interface.
|
||||
func (m *memoryFile) Truncate(size int64) error {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
return m.truncate(size)
|
||||
}
|
||||
|
||||
// +checklocks:m.dataMtx
|
||||
func (m *memoryFile) truncate(size int64) error {
|
||||
if m.MaxSize > 0 && size > m.MaxSize {
|
||||
return _FULL
|
||||
}
|
||||
if size < m.size {
|
||||
base := size / memSectorSize
|
||||
rest := size % memSectorSize
|
||||
if rest != 0 {
|
||||
clear((*m.data[base])[rest:])
|
||||
}
|
||||
}
|
||||
sectors := divRoundUp(size, memSectorSize)
|
||||
for sectors > int64(len(m.data)) {
|
||||
m.data = append(m.data, new([memSectorSize]byte))
|
||||
}
|
||||
clear(m.data[sectors:])
|
||||
m.data = m.data[:sectors]
|
||||
m.size = size
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync implements the [File] interface.
|
||||
func (*memoryFile) Sync(flag SyncFlag) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Size implements the [File] interface.
|
||||
func (m *memoryFile) Size() (int64, error) {
|
||||
m.dataMtx.RLock()
|
||||
defer m.dataMtx.RUnlock()
|
||||
return m.size, nil
|
||||
}
|
||||
|
||||
// Lock implements the [File] interface.
|
||||
func (m *memoryFile) Lock(lock LockLevel) error {
|
||||
if m.lock >= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.readOnly && lock >= LOCK_RESERVED {
|
||||
return _IOERR_LOCK
|
||||
}
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
deadline := time.Now().Add(time.Millisecond)
|
||||
|
||||
switch lock {
|
||||
case LOCK_SHARED:
|
||||
for m.pending != nil {
|
||||
if time.Now().After(deadline) {
|
||||
return _BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
}
|
||||
m.shared++
|
||||
|
||||
case LOCK_RESERVED:
|
||||
if m.reserved != nil {
|
||||
return _BUSY
|
||||
}
|
||||
m.reserved = m
|
||||
|
||||
case LOCK_EXCLUSIVE:
|
||||
if m.lock < LOCK_PENDING {
|
||||
if m.pending != nil {
|
||||
return _BUSY
|
||||
}
|
||||
m.lock = LOCK_PENDING
|
||||
m.pending = m
|
||||
}
|
||||
|
||||
for m.shared > 1 {
|
||||
if time.Now().After(deadline) {
|
||||
return _BUSY
|
||||
}
|
||||
m.lockMtx.Unlock()
|
||||
runtime.Gosched()
|
||||
m.lockMtx.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
m.lock = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlock implements the [File] interface.
|
||||
func (m *memoryFile) Unlock(lock LockLevel) error {
|
||||
if m.lock <= lock {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
|
||||
if m.pending == m {
|
||||
m.pending = nil
|
||||
}
|
||||
if m.reserved == m {
|
||||
m.reserved = nil
|
||||
}
|
||||
if lock < LOCK_SHARED {
|
||||
m.shared--
|
||||
}
|
||||
m.lock = lock
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckReservedLock implements the [File] interface.
|
||||
func (m *memoryFile) CheckReservedLock() (bool, error) {
|
||||
if m.lock >= LOCK_RESERVED {
|
||||
return true, nil
|
||||
}
|
||||
m.lockMtx.Lock()
|
||||
defer m.lockMtx.Unlock()
|
||||
return m.reserved != nil, nil
|
||||
}
|
||||
|
||||
// SectorSize implements the [File] interface.
|
||||
func (*memoryFile) SectorSize() int {
|
||||
return memSectorSize
|
||||
}
|
||||
|
||||
// DeviceCharacteristics implements the [File] interface.
|
||||
func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic {
|
||||
return IOCAP_ATOMIC |
|
||||
IOCAP_SEQUENTIAL |
|
||||
IOCAP_SAFE_APPEND |
|
||||
IOCAP_POWERSAFE_OVERWRITE
|
||||
}
|
||||
|
||||
// SizeHint implements the [FileSizeHint] interface.
|
||||
func (m *memoryFile) SizeHint(size int64) error {
|
||||
m.dataMtx.Lock()
|
||||
defer m.dataMtx.Unlock()
|
||||
if size > m.size {
|
||||
return m.truncate(size)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LockState implements the [FileLockState] interface.
|
||||
func (m *memoryFile) LockState() LockLevel {
|
||||
return m.lock
|
||||
}
|
||||
|
||||
func divRoundUp(a, b int64) int64 {
|
||||
return (a + b - 1) / b
|
||||
}
|
||||
|
||||
func modRoundUp(a, b int64) int64 {
|
||||
return b - (b-a%b)%b
|
||||
}
|
||||
|
||||
func clear[T any](b []T) {
|
||||
var zero T
|
||||
for i := range b {
|
||||
b[i] = zero
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/sqlite3memdb"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
"github.com/tetratelabs/wazero"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
@@ -32,7 +33,6 @@ var (
|
||||
rt wazero.Runtime
|
||||
module wazero.CompiledModule
|
||||
instances atomic.Uint64
|
||||
memory = sqlite3vfs.MemoryVFS{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -52,8 +52,6 @@ func init() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sqlite3vfs.Register("memvfs", memory)
|
||||
}
|
||||
|
||||
func config(ctx context.Context) wazero.ModuleConfig {
|
||||
@@ -75,28 +73,11 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 {
|
||||
buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr)
|
||||
buf = buf[:bytes.IndexByte(buf, 0)]
|
||||
|
||||
var memvfs, journal, timeout bool
|
||||
args := strings.Split(string(buf), " ")
|
||||
for i := range args {
|
||||
args[i] = strings.Trim(args[i], `"`)
|
||||
switch args[i] {
|
||||
case "memvfs":
|
||||
memvfs = true
|
||||
case "--timeout":
|
||||
timeout = true
|
||||
case "--journalmode":
|
||||
journal = true
|
||||
}
|
||||
}
|
||||
args = args[:len(args)-1]
|
||||
if memvfs {
|
||||
if !timeout {
|
||||
args = append(args, "--timeout", "1000")
|
||||
}
|
||||
if !journal {
|
||||
args = append(args, "--journalmode", "memory")
|
||||
}
|
||||
}
|
||||
|
||||
cfg := config(ctx).WithArgs(args...)
|
||||
go func() {
|
||||
@@ -172,13 +153,11 @@ func Test_multiwrite01(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_config01_memory(t *testing.T) {
|
||||
memory["test.db"] = new(sqlite3vfs.MemoryDB)
|
||||
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
|
||||
cfg := config(ctx).WithArgs("mptest", "test.db",
|
||||
"config01.test",
|
||||
"--vfs", "memvfs",
|
||||
"--timeout", "1000",
|
||||
"--journalmode", "memory")
|
||||
"--vfs", "memdb",
|
||||
"--timeout", "1000")
|
||||
mod, err := rt.InstantiateModule(ctx, module, cfg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -192,13 +171,11 @@ func Test_multiwrite01_memory(t *testing.T) {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
memory["test.db"] = new(sqlite3vfs.MemoryDB)
|
||||
ctx, vfs := sqlite3vfs.NewContext(newContext(t))
|
||||
cfg := config(ctx).WithArgs("mptest", "test.db",
|
||||
cfg := config(ctx).WithArgs("mptest", "/test.db",
|
||||
"multiwrite01.test",
|
||||
"--vfs", "memvfs",
|
||||
"--timeout", "1000",
|
||||
"--journalmode", "memory")
|
||||
"--vfs", "memdb",
|
||||
"--timeout", "1000")
|
||||
mod, err := rt.InstantiateModule(ctx, module, cfg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
@@ -59,7 +59,7 @@ type vfsState struct {
|
||||
//
|
||||
// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly.
|
||||
func NewContext(ctx context.Context) (context.Context, io.Closer) {
|
||||
vfs := &vfsState{}
|
||||
vfs := new(vfsState)
|
||||
return context.WithValue(ctx, vfsKey{}, vfs), vfs
|
||||
}
|
||||
|
||||
@@ -457,3 +457,9 @@ func vfsErrorCode(err error, def _ErrorCode) _ErrorCode {
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func clear(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
_ "github.com/ncruces/go-sqlite3/sqlite3memdb"
|
||||
)
|
||||
|
||||
func TestDB_memory(t *testing.T) {
|
||||
@@ -19,12 +19,8 @@ func TestDB_file(t *testing.T) {
|
||||
testDB(t, filepath.Join(t.TempDir(), "test.db"))
|
||||
}
|
||||
|
||||
func TestDB_VFS(t *testing.T) {
|
||||
sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{
|
||||
"test.db": &sqlite3vfs.MemoryDB{},
|
||||
})
|
||||
defer sqlite3vfs.Unregister("memvfs")
|
||||
testDB(t, "file:test.db?vfs=memvfs")
|
||||
func TestDB_vfs(t *testing.T) {
|
||||
testDB(t, "file:test.db?vfs=memdb")
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, name string) {
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
_ "github.com/ncruces/go-sqlite3/sqlite3memdb"
|
||||
)
|
||||
|
||||
func TestParallel(t *testing.T) {
|
||||
@@ -40,12 +40,7 @@ func TestMemory(t *testing.T) {
|
||||
iter = 5000
|
||||
}
|
||||
|
||||
sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{
|
||||
"test.db": &sqlite3vfs.MemoryDB{},
|
||||
})
|
||||
defer sqlite3vfs.Unregister("memvfs")
|
||||
|
||||
name := "file:test.db?vfs=memvfs" +
|
||||
name := "file:/test.db?vfs=memdb" +
|
||||
"&_pragma=busy_timeout(10000)" +
|
||||
"&_pragma=locking_mode(normal)" +
|
||||
"&_pragma=journal_mode(memory)" +
|
||||
|
||||
@@ -2,27 +2,19 @@ package tests
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3memdb"
|
||||
"github.com/ncruces/go-sqlite3/sqlite3vfs"
|
||||
)
|
||||
|
||||
//go:embed testdata/test.db
|
||||
var testdata string
|
||||
|
||||
func TestMemoryVFS_Open_notfound(t *testing.T) {
|
||||
sqlite3vfs.Register("memory", sqlite3vfs.MemoryVFS{
|
||||
"test.db": &sqlite3vfs.MemoryDB{},
|
||||
})
|
||||
defer sqlite3vfs.Unregister("memory")
|
||||
sqlite3memdb.Delete("demo.db")
|
||||
|
||||
_, err := sqlite3.Open("file:demo.db?vfs=memory&mode=ro")
|
||||
_, err := sqlite3.Open("file:/demo.db?vfs=memdb&mode=ro")
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
@@ -31,36 +23,8 @@ func TestMemoryVFS_Open_notfound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryVFS_Open_errors(t *testing.T) {
|
||||
sqlite3vfs.Register("memory", sqlite3vfs.MemoryVFS{
|
||||
"test.db": &sqlite3vfs.MemoryDB{MaxSize: 65536},
|
||||
})
|
||||
defer sqlite3vfs.Unregister("memory")
|
||||
|
||||
db, err := sqlite3.Open("file:test.db?vfs=memory")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO test VALUES (zeroblob(65536))`)
|
||||
if err == nil {
|
||||
t.Error("want error")
|
||||
}
|
||||
if !errors.Is(err, sqlite3.FULL) {
|
||||
t.Errorf("got %v, want sqlite3.FULL", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReaderVFS_Open_notfound(t *testing.T) {
|
||||
sqlite3vfs.Register("reader", sqlite3vfs.ReaderVFS{
|
||||
"test.db": sqlite3vfs.NewSizeReaderAt(strings.NewReader(testdata)),
|
||||
})
|
||||
sqlite3vfs.Register("reader", sqlite3vfs.ReaderVFS{})
|
||||
defer sqlite3vfs.Unregister("reader")
|
||||
|
||||
_, err := sqlite3.Open("file:demo.db?vfs=reader&mode=ro")
|
||||
|
||||
Reference in New Issue
Block a user