diff --git a/module.go b/module.go index 70d6270..3d3221a 100644 --- a/module.go +++ b/module.go @@ -83,7 +83,7 @@ type module struct { } func newModule(mod api.Module) (m *module, err error) { - m = &module{} + m = new(module) m.mod = mod m.ctx, m.vfs = sqlite3vfs.NewContext(context.Background()) diff --git a/sqlite3memdb/api.go b/sqlite3memdb/api.go new file mode 100644 index 0000000..c41978b --- /dev/null +++ b/sqlite3memdb/api.go @@ -0,0 +1,58 @@ +// Package sqlite3memdb implements the "memdb" SQLite VFS. +// +// The "memdb" [sqlite3vfs.VFS] allows the same in-memory database to be shared +// among multiple database connections in the same process, +// as long as the database name begins with "/". +// +// Importing package sqlite3memdb registers the VFS. +// +// import _ "github.com/ncruces/go-sqlite3/sqlite3memdb" +package sqlite3memdb + +import ( + "sync" + + "github.com/ncruces/go-sqlite3/sqlite3vfs" +) + +func init() { + sqlite3vfs.Register("memdb", vfs{}) +} + +var ( + memoryMtx sync.Mutex + memoryDBs = map[string]*dbase{} +) + +// Create creates a shared memory database, +// using data as its initial contents. +// The new database takes ownership of data, +// and the caller should not use data after this call. +func Create(name string, data []byte) { + memoryMtx.Lock() + defer memoryMtx.Unlock() + + db := new(dbase) + db.size = int64(len(data)) + + sectors := divRoundUp(db.size, sectorSize) + db.data = make([]*[sectorSize]byte, sectors) + for i := range db.data { + sector := data[i*sectorSize:] + if len(sector) >= sectorSize { + db.data[i] = (*[sectorSize]byte)(sector) + } else { + db.data[i] = new([sectorSize]byte) + copy((*db.data[i])[:], sector) + } + } + + memoryDBs[name] = db +} + +// Delete deletes a shared memory database. +func Delete(name string) { + memoryMtx.Lock() + defer memoryMtx.Unlock() + delete(memoryDBs, name) +} diff --git a/sqlite3memdb/example_test.go b/sqlite3memdb/example_test.go new file mode 100644 index 0000000..52ec70e --- /dev/null +++ b/sqlite3memdb/example_test.go @@ -0,0 +1,51 @@ +package sqlite3memdb_test + +import ( + "database/sql" + "fmt" + "log" + + _ "embed" + + _ "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/sqlite3memdb" +) + +//go:embed testdata/test.db +var testDB []byte + +func Example() { + sqlite3memdb.Create("test.db", testDB) + + db, err := sql.Open("sqlite3", "file:/test.db?vfs=memdb") + if err != nil { + log.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`) + if err != nil { + log.Fatal(err) + } + + rows, err := db.Query(`SELECT id, name FROM users`) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id, name string + err = rows.Scan(&id, &name) + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s %s\n", id, name) + } + // Output: + // 0 go + // 1 zig + // 2 whatever + // 3 rust +} diff --git a/sqlite3memdb/memdb.go b/sqlite3memdb/memdb.go new file mode 100644 index 0000000..3af79bc --- /dev/null +++ b/sqlite3memdb/memdb.go @@ -0,0 +1,292 @@ +package sqlite3memdb + +import ( + "io" + "runtime" + "strings" + "sync" + "time" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/sqlite3vfs" +) + +type vfs struct{} + +func (vfs) Open(name string, flags sqlite3vfs.OpenFlag) (sqlite3vfs.File, sqlite3vfs.OpenFlag, error) { + if flags&sqlite3vfs.OPEN_MAIN_DB == 0 { + return nil, flags, sqlite3.CANTOPEN + } + + var db *dbase + + shared := strings.HasPrefix(name, "/") + if shared { + memoryMtx.Lock() + defer memoryMtx.Unlock() + db = memoryDBs[name[1:]] + } + if db == nil { + if flags&sqlite3vfs.OPEN_CREATE == 0 { + return nil, flags, sqlite3.CANTOPEN + } + db = new(dbase) + } + if shared { + memoryDBs[name[1:]] = db + } + + return &file{ + dbase: db, + readOnly: flags&sqlite3vfs.OPEN_READONLY != 0, + }, flags | sqlite3vfs.OPEN_MEMORY, nil +} + +func (vfs) Delete(name string, dirSync bool) error { + return sqlite3.IOERR_DELETE +} + +func (vfs) Access(name string, flag sqlite3vfs.AccessFlag) (bool, error) { + return false, nil +} + +func (vfs) FullPathname(name string) (string, error) { + return name, nil +} + +const sectorSize = 65536 + +type dbase struct { + // +checklocks:lockMtx + pending *file + // +checklocks:lockMtx + reserved *file + + // +checklocks:dataMtx + data []*[sectorSize]byte + + // +checklocks:dataMtx + size int64 + + // +checklocks:lockMtx + shared int + + lockMtx sync.Mutex + dataMtx sync.RWMutex +} + +type file struct { + *dbase + lock sqlite3vfs.LockLevel + readOnly bool +} + +var ( + // Ensure these interfaces are implemented: + _ sqlite3vfs.FileLockState = &file{} + _ sqlite3vfs.FileSizeHint = &file{} +) + +func (m *file) Close() error { + return m.Unlock(sqlite3vfs.LOCK_NONE) +} + +func (m *file) ReadAt(b []byte, off int64) (n int, err error) { + m.dataMtx.RLock() + defer m.dataMtx.RUnlock() + + if off >= m.size { + return 0, io.EOF + } + + base := off / sectorSize + rest := off % sectorSize + have := int64(sectorSize) + if base == int64(len(m.data))-1 { + have = modRoundUp(m.size, sectorSize) + } + n = copy(b, (*m.data[base])[rest:have]) + if n < len(b) { + // Assume reads are page aligned. + return 0, io.ErrNoProgress + } + return n, nil +} + +func (m *file) WriteAt(b []byte, off int64) (n int, err error) { + m.dataMtx.Lock() + defer m.dataMtx.Unlock() + + base := off / sectorSize + rest := off % sectorSize + for base >= int64(len(m.data)) { + m.data = append(m.data, new([sectorSize]byte)) + } + n = copy((*m.data[base])[rest:], b) + if n < len(b) { + // Assume writes are page aligned. + return 0, io.ErrShortWrite + } + if size := off + int64(len(b)); size > m.size { + m.size = size + } + return n, nil +} + +func (m *file) Truncate(size int64) error { + m.dataMtx.Lock() + defer m.dataMtx.Unlock() + return m.truncate(size) +} + +// +checklocks:m.dataMtx +func (m *file) truncate(size int64) error { + if size < m.size { + base := size / sectorSize + rest := size % sectorSize + if rest != 0 { + clear((*m.data[base])[rest:]) + } + } + sectors := divRoundUp(size, sectorSize) + for sectors > int64(len(m.data)) { + m.data = append(m.data, new([sectorSize]byte)) + } + clear(m.data[sectors:]) + m.data = m.data[:sectors] + m.size = size + return nil +} + +func (*file) Sync(flag sqlite3vfs.SyncFlag) error { + return nil +} + +func (m *file) Size() (int64, error) { + m.dataMtx.RLock() + defer m.dataMtx.RUnlock() + return m.size, nil +} + +func (m *file) Lock(lock sqlite3vfs.LockLevel) error { + if m.lock >= lock { + return nil + } + + if m.readOnly && lock >= sqlite3vfs.LOCK_RESERVED { + return sqlite3.IOERR_LOCK + } + + m.lockMtx.Lock() + defer m.lockMtx.Unlock() + deadline := time.Now().Add(time.Millisecond) + + switch lock { + case sqlite3vfs.LOCK_SHARED: + for m.pending != nil { + if time.Now().After(deadline) { + return sqlite3.BUSY + } + m.lockMtx.Unlock() + runtime.Gosched() + m.lockMtx.Lock() + } + m.shared++ + + case sqlite3vfs.LOCK_RESERVED: + if m.reserved != nil { + return sqlite3.BUSY + } + m.reserved = m + + case sqlite3vfs.LOCK_EXCLUSIVE: + if m.lock < sqlite3vfs.LOCK_PENDING { + if m.pending != nil { + return sqlite3.BUSY + } + m.lock = sqlite3vfs.LOCK_PENDING + m.pending = m + } + + for m.shared > 1 { + if time.Now().After(deadline) { + return sqlite3.BUSY + } + m.lockMtx.Unlock() + runtime.Gosched() + m.lockMtx.Lock() + } + } + + m.lock = lock + return nil +} + +func (m *file) Unlock(lock sqlite3vfs.LockLevel) error { + if m.lock <= lock { + return nil + } + + m.lockMtx.Lock() + defer m.lockMtx.Unlock() + + if m.pending == m { + m.pending = nil + } + if m.reserved == m { + m.reserved = nil + } + if lock < sqlite3vfs.LOCK_SHARED { + m.shared-- + } + m.lock = lock + return nil +} + +func (m *file) CheckReservedLock() (bool, error) { + if m.lock >= sqlite3vfs.LOCK_RESERVED { + return true, nil + } + m.lockMtx.Lock() + defer m.lockMtx.Unlock() + return m.reserved != nil, nil +} + +func (*file) SectorSize() int { + return sectorSize +} + +func (*file) DeviceCharacteristics() sqlite3vfs.DeviceCharacteristic { + return sqlite3vfs.IOCAP_ATOMIC | + sqlite3vfs.IOCAP_SEQUENTIAL | + sqlite3vfs.IOCAP_SAFE_APPEND | + sqlite3vfs.IOCAP_POWERSAFE_OVERWRITE +} + +func (m *file) SizeHint(size int64) error { + m.dataMtx.Lock() + defer m.dataMtx.Unlock() + if size > m.size { + return m.truncate(size) + } + return nil +} + +func (m *file) LockState() sqlite3vfs.LockLevel { + return m.lock +} + +func divRoundUp(a, b int64) int64 { + return (a + b - 1) / b +} + +func modRoundUp(a, b int64) int64 { + return b - (b-a%b)%b +} + +func clear[T any](b []T) { + var zero T + for i := range b { + b[i] = zero + } +} diff --git a/tests/testdata/test.db b/sqlite3memdb/testdata/test.db similarity index 100% rename from tests/testdata/test.db rename to sqlite3memdb/testdata/test.db diff --git a/sqlite3vfs/const.go b/sqlite3vfs/const.go index 02cab90..756be70 100644 --- a/sqlite3vfs/const.go +++ b/sqlite3vfs/const.go @@ -22,7 +22,6 @@ const ( _READONLY _ErrorCode = util.READONLY _IOERR _ErrorCode = util.IOERR _NOTFOUND _ErrorCode = util.NOTFOUND - _FULL _ErrorCode = util.FULL _CANTOPEN _ErrorCode = util.CANTOPEN _IOERR_READ _ErrorCode = util.IOERR_READ _IOERR_SHORT_READ _ErrorCode = util.IOERR_SHORT_READ diff --git a/sqlite3vfs/example_test.go b/sqlite3vfs/example_test.go index fb526a1..ff4507b 100644 --- a/sqlite3vfs/example_test.go +++ b/sqlite3vfs/example_test.go @@ -17,43 +17,6 @@ import ( //go:embed testdata/test.db var testDB []byte -func ExampleMemoryVFS_embed() { - sqlite3vfs.Register("memory", sqlite3vfs.MemoryVFS{ - "test.db": sqlite3vfs.NewMemoryDB(testDB), - }) - - db, err := sql.Open("sqlite3", "file:test.db?vfs=memory") - if err != nil { - log.Fatal(err) - } - defer db.Close() - - _, err = db.Exec(`INSERT INTO users (id, name) VALUES (3, 'rust')`) - if err != nil { - log.Fatal(err) - } - - rows, err := db.Query(`SELECT id, name FROM users`) - if err != nil { - log.Fatal(err) - } - defer rows.Close() - - for rows.Next() { - var id, name string - err = rows.Scan(&id, &name) - if err != nil { - log.Fatal(err) - } - fmt.Printf("%s %s\n", id, name) - } - // Output: - // 0 go - // 1 zig - // 2 whatever - // 3 rust -} - func ExampleReaderVFS_http() { sqlite3vfs.Register("httpvfs", sqlite3vfs.ReaderVFS{ "demo.db": httpreadat.New("https://www.sanford.io/demo.db"), diff --git a/sqlite3vfs/memory.go b/sqlite3vfs/memory.go deleted file mode 100644 index 697cc1f..0000000 --- a/sqlite3vfs/memory.go +++ /dev/null @@ -1,325 +0,0 @@ -package sqlite3vfs - -import ( - "io" - "runtime" - "sync" - "time" -) - -// A MemoryVFS is a [VFS] for memory databases. -type MemoryVFS map[string]*MemoryDB - -var _ VFS = MemoryVFS{} - -// Open implements the [VFS] interface. -func (vfs MemoryVFS) Open(name string, flags OpenFlag) (File, OpenFlag, error) { - if flags&OPEN_MAIN_DB == 0 { - return nil, flags, _CANTOPEN - } - if db, ok := vfs[name]; ok { - return &memoryFile{ - MemoryDB: db, - readOnly: flags&OPEN_READONLY != 0, - }, flags | OPEN_MEMORY, nil - } - return nil, flags, _CANTOPEN -} - -// Delete implements the [VFS] interface. -func (vfs MemoryVFS) Delete(name string, dirSync bool) error { - return _IOERR_DELETE -} - -// Access implements the [VFS] interface. -func (vfs MemoryVFS) Access(name string, flag AccessFlag) (bool, error) { - return false, nil -} - -// FullPathname implements the [VFS] interface. -func (vfs MemoryVFS) FullPathname(name string) (string, error) { - return name, nil -} - -const memSectorSize = 65536 - -// A MemoryDB is a [MemoryVFS] database. -// -// A MemoryDB is safe to access concurrently through multiple SQLite connections. -type MemoryDB struct { - // +checklocks:dataMtx - MaxSize int64 - - // +checklocks:dataMtx - data []*[memSectorSize]byte - // +checklocks:dataMtx - size int64 - - // +checklocks:lockMtx - pending *memoryFile - // +checklocks:lockMtx - reserved *memoryFile - // +checklocks:lockMtx - shared int - - lockMtx sync.Mutex - dataMtx sync.RWMutex -} - -// NewMemoryDB creates a new MemoryDB using mem as its initial contents. -// The new MemoryDB takes ownership of mem, and the caller should not use mem after this call. -func NewMemoryDB(mem []byte) *MemoryDB { - m := new(MemoryDB) - m.size = int64(len(mem)) - - sectors := divRoundUp(m.size, memSectorSize) - m.data = make([]*[memSectorSize]byte, sectors) - for i := range m.data { - sector := mem[i*memSectorSize:] - if len(sector) >= memSectorSize { - m.data[i] = (*[memSectorSize]byte)(sector) - } else { - m.data[i] = new([memSectorSize]byte) - copy((*m.data[i])[:], sector) - } - } - - return m -} - -type memoryFile struct { - *MemoryDB - lock LockLevel - readOnly bool -} - -var ( - // Ensure these interfaces are implemented: - _ FileLockState = &memoryFile{} - _ FileSizeHint = &memoryFile{} -) - -// Close implements the [File] and [io.Closer] interfaces. -func (m *memoryFile) Close() error { - return m.Unlock(LOCK_NONE) -} - -// ReadAt implements the [File] and [io.ReaderAt] interfaces. -func (m *memoryFile) ReadAt(b []byte, off int64) (n int, err error) { - m.dataMtx.RLock() - defer m.dataMtx.RUnlock() - - if off >= m.size { - return 0, io.EOF - } - - base := off / memSectorSize - rest := off % memSectorSize - have := int64(memSectorSize) - if base == int64(len(m.data))-1 { - have = modRoundUp(m.size, memSectorSize) - } - n = copy(b, (*m.data[base])[rest:have]) - if n < len(b) { - // Assume reads are page aligned. - return 0, io.ErrNoProgress - } - return n, nil -} - -// WriteAt implements the [File] and [io.WriterAt] interfaces. -func (m *memoryFile) WriteAt(b []byte, off int64) (n int, err error) { - m.dataMtx.Lock() - defer m.dataMtx.Unlock() - - size := off + int64(len(b)) - if m.MaxSize > 0 && size > m.MaxSize { - return 0, _FULL - } - - base := off / memSectorSize - rest := off % memSectorSize - for base >= int64(len(m.data)) { - m.data = append(m.data, new([memSectorSize]byte)) - } - n = copy((*m.data[base])[rest:], b) - if n < len(b) { - // Assume writes are page aligned. - return 0, io.ErrShortWrite - } - if size > m.size { - m.size = size - } - return n, nil -} - -// Truncate implements the [File] interface. -func (m *memoryFile) Truncate(size int64) error { - m.dataMtx.Lock() - defer m.dataMtx.Unlock() - return m.truncate(size) -} - -// +checklocks:m.dataMtx -func (m *memoryFile) truncate(size int64) error { - if m.MaxSize > 0 && size > m.MaxSize { - return _FULL - } - if size < m.size { - base := size / memSectorSize - rest := size % memSectorSize - if rest != 0 { - clear((*m.data[base])[rest:]) - } - } - sectors := divRoundUp(size, memSectorSize) - for sectors > int64(len(m.data)) { - m.data = append(m.data, new([memSectorSize]byte)) - } - clear(m.data[sectors:]) - m.data = m.data[:sectors] - m.size = size - return nil -} - -// Sync implements the [File] interface. -func (*memoryFile) Sync(flag SyncFlag) error { - return nil -} - -// Size implements the [File] interface. -func (m *memoryFile) Size() (int64, error) { - m.dataMtx.RLock() - defer m.dataMtx.RUnlock() - return m.size, nil -} - -// Lock implements the [File] interface. -func (m *memoryFile) Lock(lock LockLevel) error { - if m.lock >= lock { - return nil - } - - if m.readOnly && lock >= LOCK_RESERVED { - return _IOERR_LOCK - } - - m.lockMtx.Lock() - defer m.lockMtx.Unlock() - deadline := time.Now().Add(time.Millisecond) - - switch lock { - case LOCK_SHARED: - for m.pending != nil { - if time.Now().After(deadline) { - return _BUSY - } - m.lockMtx.Unlock() - runtime.Gosched() - m.lockMtx.Lock() - } - m.shared++ - - case LOCK_RESERVED: - if m.reserved != nil { - return _BUSY - } - m.reserved = m - - case LOCK_EXCLUSIVE: - if m.lock < LOCK_PENDING { - if m.pending != nil { - return _BUSY - } - m.lock = LOCK_PENDING - m.pending = m - } - - for m.shared > 1 { - if time.Now().After(deadline) { - return _BUSY - } - m.lockMtx.Unlock() - runtime.Gosched() - m.lockMtx.Lock() - } - } - - m.lock = lock - return nil -} - -// Unlock implements the [File] interface. -func (m *memoryFile) Unlock(lock LockLevel) error { - if m.lock <= lock { - return nil - } - - m.lockMtx.Lock() - defer m.lockMtx.Unlock() - - if m.pending == m { - m.pending = nil - } - if m.reserved == m { - m.reserved = nil - } - if lock < LOCK_SHARED { - m.shared-- - } - m.lock = lock - return nil -} - -// CheckReservedLock implements the [File] interface. -func (m *memoryFile) CheckReservedLock() (bool, error) { - if m.lock >= LOCK_RESERVED { - return true, nil - } - m.lockMtx.Lock() - defer m.lockMtx.Unlock() - return m.reserved != nil, nil -} - -// SectorSize implements the [File] interface. -func (*memoryFile) SectorSize() int { - return memSectorSize -} - -// DeviceCharacteristics implements the [File] interface. -func (*memoryFile) DeviceCharacteristics() DeviceCharacteristic { - return IOCAP_ATOMIC | - IOCAP_SEQUENTIAL | - IOCAP_SAFE_APPEND | - IOCAP_POWERSAFE_OVERWRITE -} - -// SizeHint implements the [FileSizeHint] interface. -func (m *memoryFile) SizeHint(size int64) error { - m.dataMtx.Lock() - defer m.dataMtx.Unlock() - if size > m.size { - return m.truncate(size) - } - return nil -} - -// LockState implements the [FileLockState] interface. -func (m *memoryFile) LockState() LockLevel { - return m.lock -} - -func divRoundUp(a, b int64) int64 { - return (a + b - 1) / b -} - -func modRoundUp(a, b int64) int64 { - return b - (b-a%b)%b -} - -func clear[T any](b []T) { - var zero T - for i := range b { - b[i] = zero - } -} diff --git a/sqlite3vfs/tests/mptest/mptest_test.go b/sqlite3vfs/tests/mptest/mptest_test.go index e264d12..73ed79f 100644 --- a/sqlite3vfs/tests/mptest/mptest_test.go +++ b/sqlite3vfs/tests/mptest/mptest_test.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "testing" + _ "github.com/ncruces/go-sqlite3/sqlite3memdb" "github.com/ncruces/go-sqlite3/sqlite3vfs" "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" @@ -32,7 +33,6 @@ var ( rt wazero.Runtime module wazero.CompiledModule instances atomic.Uint64 - memory = sqlite3vfs.MemoryVFS{} ) func init() { @@ -52,8 +52,6 @@ func init() { if err != nil { panic(err) } - - sqlite3vfs.Register("memvfs", memory) } func config(ctx context.Context) wazero.ModuleConfig { @@ -75,28 +73,11 @@ func system(ctx context.Context, mod api.Module, ptr uint32) uint32 { buf, _ := mod.Memory().Read(ptr, mod.Memory().Size()-ptr) buf = buf[:bytes.IndexByte(buf, 0)] - var memvfs, journal, timeout bool args := strings.Split(string(buf), " ") for i := range args { args[i] = strings.Trim(args[i], `"`) - switch args[i] { - case "memvfs": - memvfs = true - case "--timeout": - timeout = true - case "--journalmode": - journal = true - } } args = args[:len(args)-1] - if memvfs { - if !timeout { - args = append(args, "--timeout", "1000") - } - if !journal { - args = append(args, "--journalmode", "memory") - } - } cfg := config(ctx).WithArgs(args...) go func() { @@ -172,13 +153,11 @@ func Test_multiwrite01(t *testing.T) { } func Test_config01_memory(t *testing.T) { - memory["test.db"] = new(sqlite3vfs.MemoryDB) ctx, vfs := sqlite3vfs.NewContext(newContext(t)) cfg := config(ctx).WithArgs("mptest", "test.db", "config01.test", - "--vfs", "memvfs", - "--timeout", "1000", - "--journalmode", "memory") + "--vfs", "memdb", + "--timeout", "1000") mod, err := rt.InstantiateModule(ctx, module, cfg) if err != nil { t.Error(err) @@ -192,13 +171,11 @@ func Test_multiwrite01_memory(t *testing.T) { t.Skip("skipping in short mode") } - memory["test.db"] = new(sqlite3vfs.MemoryDB) ctx, vfs := sqlite3vfs.NewContext(newContext(t)) - cfg := config(ctx).WithArgs("mptest", "test.db", + cfg := config(ctx).WithArgs("mptest", "/test.db", "multiwrite01.test", - "--vfs", "memvfs", - "--timeout", "1000", - "--journalmode", "memory") + "--vfs", "memdb", + "--timeout", "1000") mod, err := rt.InstantiateModule(ctx, module, cfg) if err != nil { t.Error(err) diff --git a/sqlite3vfs/vfs.go b/sqlite3vfs/vfs.go index fbd9f3a..79bc6a4 100644 --- a/sqlite3vfs/vfs.go +++ b/sqlite3vfs/vfs.go @@ -59,7 +59,7 @@ type vfsState struct { // // Users of the [github.com/ncruces/go-sqlite3] package need not call this directly. func NewContext(ctx context.Context) (context.Context, io.Closer) { - vfs := &vfsState{} + vfs := new(vfsState) return context.WithValue(ctx, vfsKey{}, vfs), vfs } @@ -457,3 +457,9 @@ func vfsErrorCode(err error, def _ErrorCode) _ErrorCode { } return def } + +func clear(b []byte) { + for i := range b { + b[i] = 0 + } +} diff --git a/tests/db_test.go b/tests/db_test.go index d51da16..6b4d185 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -6,7 +6,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" - "github.com/ncruces/go-sqlite3/sqlite3vfs" + _ "github.com/ncruces/go-sqlite3/sqlite3memdb" ) func TestDB_memory(t *testing.T) { @@ -19,12 +19,8 @@ func TestDB_file(t *testing.T) { testDB(t, filepath.Join(t.TempDir(), "test.db")) } -func TestDB_VFS(t *testing.T) { - sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{ - "test.db": &sqlite3vfs.MemoryDB{}, - }) - defer sqlite3vfs.Unregister("memvfs") - testDB(t, "file:test.db?vfs=memvfs") +func TestDB_vfs(t *testing.T) { + testDB(t, "file:test.db?vfs=memdb") } func testDB(t *testing.T, name string) { diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index 21cfe4b..c9e682e 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -11,7 +11,7 @@ import ( "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" - "github.com/ncruces/go-sqlite3/sqlite3vfs" + _ "github.com/ncruces/go-sqlite3/sqlite3memdb" ) func TestParallel(t *testing.T) { @@ -40,12 +40,7 @@ func TestMemory(t *testing.T) { iter = 5000 } - sqlite3vfs.Register("memvfs", sqlite3vfs.MemoryVFS{ - "test.db": &sqlite3vfs.MemoryDB{}, - }) - defer sqlite3vfs.Unregister("memvfs") - - name := "file:test.db?vfs=memvfs" + + name := "file:/test.db?vfs=memdb" + "&_pragma=busy_timeout(10000)" + "&_pragma=locking_mode(normal)" + "&_pragma=journal_mode(memory)" + diff --git a/tests/vfs_test.go b/tests/vfs_test.go index 2be7d7e..c6c9e9b 100644 --- a/tests/vfs_test.go +++ b/tests/vfs_test.go @@ -2,27 +2,19 @@ package tests import ( "errors" - "strings" "testing" - _ "embed" - "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/driver" _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/sqlite3memdb" "github.com/ncruces/go-sqlite3/sqlite3vfs" ) -//go:embed testdata/test.db -var testdata string - func TestMemoryVFS_Open_notfound(t *testing.T) { - sqlite3vfs.Register("memory", sqlite3vfs.MemoryVFS{ - "test.db": &sqlite3vfs.MemoryDB{}, - }) - defer sqlite3vfs.Unregister("memory") + sqlite3memdb.Delete("demo.db") - _, err := sqlite3.Open("file:demo.db?vfs=memory&mode=ro") + _, err := sqlite3.Open("file:/demo.db?vfs=memdb&mode=ro") if err == nil { t.Error("want error") } @@ -31,36 +23,8 @@ func TestMemoryVFS_Open_notfound(t *testing.T) { } } -func TestMemoryVFS_Open_errors(t *testing.T) { - sqlite3vfs.Register("memory", sqlite3vfs.MemoryVFS{ - "test.db": &sqlite3vfs.MemoryDB{MaxSize: 65536}, - }) - defer sqlite3vfs.Unregister("memory") - - db, err := sqlite3.Open("file:test.db?vfs=memory") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) - if err != nil { - t.Fatal(err) - } - - err = db.Exec(`INSERT INTO test VALUES (zeroblob(65536))`) - if err == nil { - t.Error("want error") - } - if !errors.Is(err, sqlite3.FULL) { - t.Errorf("got %v, want sqlite3.FULL", err) - } -} - func TestReaderVFS_Open_notfound(t *testing.T) { - sqlite3vfs.Register("reader", sqlite3vfs.ReaderVFS{ - "test.db": sqlite3vfs.NewSizeReaderAt(strings.NewReader(testdata)), - }) + sqlite3vfs.Register("reader", sqlite3vfs.ReaderVFS{}) defer sqlite3vfs.Unregister("reader") _, err := sqlite3.Open("file:demo.db?vfs=reader&mode=ro")