Serdes robustness.

This commit is contained in:
Nuno Cruces
2025-02-12 00:41:16 +00:00
parent 9b4002f5ac
commit 30c1bcdbe9
2 changed files with 33 additions and 10 deletions

View File

@@ -8,17 +8,21 @@ import (
"github.com/ncruces/go-sqlite3/vfs"
)
const vfsName = "github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS"
func init() {
vfs.Register(vfsName, sliceVFS{})
}
var fileToOpen = make(chan *sliceFile, 1)
// Serialize backs up a database into a byte slice.
//
// https://sqlite.org/c3ref/serialize.html
func Serialize(db *sqlite3.Conn, schema string) ([]byte, error) {
var file sliceFile
fileToOpen <- &file
err := db.Backup(schema, "file:db?vfs="+vfsName)
err := db.Backup(schema, "file:serdes.db?vfs="+vfsName)
return file.data, err
}
@@ -38,21 +42,21 @@ func Serialize(db *sqlite3.Conn, schema string) ([]byte, error) {
// ["reader"]: https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs/readervfs
func Deserialize(db *sqlite3.Conn, schema string, data []byte) error {
fileToOpen <- &sliceFile{data}
return db.Restore(schema, "file:db?vfs="+vfsName)
return db.Restore(schema, "file:serdes.db?vfs="+vfsName)
}
var fileToOpen = make(chan *sliceFile, 1)
const vfsName = "github.com/ncruces/go-sqlite3/ext/deserialize.sliceVFS"
type sliceVFS struct{}
func (sliceVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
if flags&vfs.OPEN_MAIN_DB == 0 {
// notest // OPEN_MEMORY
if flags&vfs.OPEN_MAIN_DB == 0 || name != "serdes.db" {
return nil, flags, sqlite3.CANTOPEN
}
return <-fileToOpen, flags | vfs.OPEN_MEMORY, nil
select {
case file := <-fileToOpen:
return file, flags | vfs.OPEN_MEMORY, nil
default:
return nil, flags, sqlite3.MISUSE
}
}
func (sliceVFS) Delete(name string, dirSync bool) error {
@@ -61,7 +65,7 @@ func (sliceVFS) Delete(name string, dirSync bool) error {
}
func (sliceVFS) Access(name string, flag vfs.AccessFlag) (bool, error) {
return name == "db", nil
return name == "serdes.db", nil
}
func (sliceVFS) FullPathname(name string) (string, error) {

View File

@@ -1,6 +1,7 @@
package serdes_test
import (
"errors"
"io"
"net/http"
"testing"
@@ -66,3 +67,21 @@ func httpGet() ([]byte, error) {
defer res.Body.Close()
return io.ReadAll(res.Body)
}
func TestOpen_errors(t *testing.T) {
_, err := sqlite3.Open("file:test.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.CANTOPEN) {
t.Errorf("got %v, want sqlite3.CANTOPEN", err)
}
_, err = sqlite3.Open("file:serdes.db?vfs=github.com/ncruces/go-sqlite3/ext/serdes.sliceVFS")
if err == nil {
t.Error("want error")
}
if !errors.Is(err, sqlite3.MISUSE) {
t.Errorf("got %v, want sqlite3.MISUSE", err)
}
}