Checksum VFS. (#176)

This commit is contained in:
Nuno Cruces
2024-10-25 00:12:29 +01:00
committed by GitHub
parent 64e2500ca8
commit 75c1dbb052
24 changed files with 499 additions and 41 deletions

View File

@@ -521,10 +521,3 @@ func (c *Conn) stmtsIter(yield func(*Stmt) bool) {
} }
} }
} }
// DriverConn is implemented by the SQLite [database/sql] driver connection.
//
// Deprecated: use [github.com/ncruces/go-sqlite3/driver.Conn] instead.
type DriverConn interface {
Raw() *Conn
}

Binary file not shown.

Binary file not shown.

BIN
tests/testdata/wal.db vendored

Binary file not shown.

View File

@@ -77,7 +77,7 @@ func TestWAL_readonly(t *testing.T) {
// Select the data using the second (readonly) connection. // Select the data using the second (readonly) connection.
var name string var name string
err = db2.QueryRow("SELECT name FROM t").Scan(&name) err = db2.QueryRow(`SELECT name FROM t`).Scan(&name)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -95,7 +95,7 @@ func TestWAL_readonly(t *testing.T) {
} }
// Select the data using the second (readonly) connection. // Select the data using the second (readonly) connection.
err = db2.QueryRow("SELECT name FROM t").Scan(&name) err = db2.QueryRow(`SELECT name FROM t`).Scan(&name)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

9
txn.go
View File

@@ -143,7 +143,7 @@ func (c *Conn) Savepoint() Savepoint {
// Names can be reused, but this makes catching bugs more likely. // Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31()))) name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.txnExecInterrupted("SAVEPOINT " + name) err := c.txnExecInterrupted(`SAVEPOINT ` + name)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -187,7 +187,7 @@ func (s Savepoint) Release(errp *error) {
if s.c.GetAutocommit() { // There is nothing to commit. if s.c.GetAutocommit() { // There is nothing to commit.
return return
} }
*errp = s.c.Exec("RELEASE " + s.name) *errp = s.c.Exec(`RELEASE ` + s.name)
if *errp == nil { if *errp == nil {
return return
} }
@@ -199,8 +199,7 @@ func (s Savepoint) Release(errp *error) {
return return
} }
// ROLLBACK and RELEASE even if interrupted. // ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted("ROLLBACK TO " + err := s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name + `; RELEASE ` + s.name)
s.name + "; RELEASE " + s.name)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -213,7 +212,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html // https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error { func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted. // ROLLBACK even if interrupted.
return s.c.txnExecInterrupted("ROLLBACK TO " + s.name) return s.c.txnExecInterrupted(`ROLLBACK TO ` + s.name)
} }
func (c *Conn) txnExecInterrupted(sql string) error { func (c *Conn) txnExecInterrupted(sql string) error {

View File

@@ -22,6 +22,14 @@ func UnwrapFile[T vfs.File](f vfs.File) (_ T, _ bool) {
} }
} }
// WrapOpenFilename helps wrap [vfs.VFSFilename].
func WrapOpenFilename(f vfs.VFS, name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
if f, ok := f.(vfs.VFSFilename); ok {
return f.OpenFilename(name, flags)
}
return f.Open(name.String(), flags)
}
// WrapLockState helps wrap [vfs.FileLockState]. // WrapLockState helps wrap [vfs.FileLockState].
func WrapLockState(f vfs.File) vfs.LockLevel { func WrapLockState(f vfs.File) vfs.LockLevel {
if f, ok := f.(vfs.FileLockState); ok { if f, ok := f.(vfs.FileLockState); ok {

View File

@@ -21,7 +21,7 @@ var testDB string
func Test_fileformat(t *testing.T) { func Test_fileformat(t *testing.T) {
readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB))) readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB)))
adiantum.Register("radiantum", vfs.Find("reader"), nil) vfs.Register("radiantum", adiantum.Wrap(vfs.Find("reader"), nil))
db, err := driver.Open("file:test.db?vfs=radiantum") db, err := driver.Open("file:test.db?vfs=radiantum")
if err != nil { if err != nil {

View File

@@ -40,24 +40,25 @@ import (
) )
func init() { func init() {
Register("adiantum", vfs.Find(""), nil) vfs.Register("adiantum", Wrap(vfs.Find(""), nil))
} }
// Register registers an encrypting VFS, wrapping a base VFS, // Wrap wraps a base VFS to create an encrypting VFS,
// and possibly using a custom HBSH cipher construction. // possibly using a custom HBSH cipher construction.
//
// To use the default Adiantum construction, set cipher to nil. // To use the default Adiantum construction, set cipher to nil.
// //
// The default construction uses a 32 byte key/hexkey. // The default construction uses a 32 byte key/hexkey.
// If a textkey is provided, the default KDF is Argon2id // If a textkey is provided, the default KDF is Argon2id
// with 64 MiB of memory, 3 iterations, and 4 threads. // with 64 MiB of memory, 3 iterations, and 4 threads.
func Register(name string, base vfs.VFS, cipher HBSHCreator) { func Wrap(base vfs.VFS, cipher HBSHCreator) vfs.VFS {
if cipher == nil { if cipher == nil {
cipher = adiantumCreator{} cipher = adiantumCreator{}
} }
vfs.Register(name, &hbshVFS{ return &hbshVFS{
VFS: base, VFS: base,
init: cipher, init: cipher,
}) }
} }
// HBSHCreator creates an [hbsh.HBSH] cipher // HBSHCreator creates an [hbsh.HBSH] cipher

View File

@@ -17,7 +17,7 @@ import (
) )
func ExampleRegister_hpolyc() { func ExampleRegister_hpolyc() {
adiantum.Register("hpolyc", vfs.Find(""), hpolycCreator{}) vfs.Register("hpolyc", adiantum.Wrap(vfs.Find(""), hpolycCreator{}))
db, err := sqlite3.Open("file:demo.db?vfs=hpolyc" + db, err := sqlite3.Open("file:demo.db?vfs=hpolyc" +
"&textkey=correct+horse+battery+staple") "&textkey=correct+horse+battery+staple")

View File

@@ -24,11 +24,7 @@ func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag,
} }
func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
if hf, ok := h.VFS.(vfs.VFSFilename); ok { file, flags, err = vfsutil.WrapOpenFilename(h.VFS, name, flags)
file, flags, err = hf.OpenFilename(name, flags)
} else {
file, flags, err = h.VFS.Open(name.String(), flags)
}
// Encrypt everything except super journals and memory files. // Encrypt everything except super journals and memory files.
if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 { if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
@@ -49,13 +45,14 @@ func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs
} else if t, ok := params["textkey"]; ok && len(t[0]) > 0 { } else if t, ok := params["textkey"]; ok && len(t[0]) > 0 {
key = h.init.KDF(t[0]) key = h.init.KDF(t[0])
} else if flags&vfs.OPEN_MAIN_DB != 0 { } else if flags&vfs.OPEN_MAIN_DB != 0 {
// Main datatabases may have their key specified as a PRAGMA. // Main databases may have their key specified as a PRAGMA.
return &hbshFile{File: file, init: h.init}, flags, nil return &hbshFile{File: file, init: h.init}, flags, nil
} }
hbsh = h.init.HBSH(key) hbsh = h.init.HBSH(key)
} }
if hbsh == nil { if hbsh == nil {
file.Close()
return nil, flags, sqlite3.CANTOPEN return nil, flags, sqlite3.CANTOPEN
} }
return &hbshFile{File: file, hbsh: hbsh, init: h.init}, flags, nil return &hbshFile{File: file, hbsh: hbsh, init: h.init}, flags, nil

20
vfs/cksmvfs/README.md Normal file
View File

@@ -0,0 +1,20 @@
# Go `cksmvfs` SQLite VFS
This package wraps an SQLite VFS to help detect database corruption.
The `"cksmvfs"` VFS wraps the default SQLite VFS adding an 8-byte checksum
to the end of every page in an SQLite database.\
The checksum is added as each page is written
and verified as each page is read.\
The checksum is intended to help detect database corruption
caused by random bit-flips in the mass storage device.
This implementation is compatible with SQLite's
[Checksum VFS Shim](https://sqlite.org/cksumvfs.html).
> [!IMPORTANT]
> [Checksums](https://en.wikipedia.org/wiki/Checksum)
> are meant to protect against _silent data corruption_ (bit rot).
> They do not offer _authenticity_ (i.e. protect against _forgery_),
> nor prevent _silent loss of durability_.
> Checkpoint WAL mode databases to improve durabiliy.

75
vfs/cksmvfs/api.go Normal file
View File

@@ -0,0 +1,75 @@
// Package cksmvfs wraps an SQLite VFS to help detect database corruption.
//
// The "cksmvfs" [vfs.VFS] wraps the default VFS adding an 8-byte checksum
// to the end of every page in an SQLite database.
// The checksum is added as each page is written
// and verified as each page is read.
// The checksum is intended to help detect database corruption
// caused by random bit-flips in the mass storage device.
//
// This implementation is compatible with SQLite's
// [Checksum VFS Shim].
//
// [Checksum VFS Shim]: https://sqlite.org/cksumvfs.html
package cksmvfs
import (
"fmt"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/vfs"
)
func init() {
vfs.Register("cksmvfs", Wrap(vfs.Find("")))
}
// Wrap wraps a base VFS to create a checksumming VFS.
func Wrap(base vfs.VFS) vfs.VFS {
return &cksmVFS{VFS: base}
}
// EnableChecksums enables checksums on a database.
func EnableChecksums(db *sqlite3.Conn, schema string) error {
if f, ok := db.Filename("").DatabaseFile().(*cksmFile); !ok {
return fmt.Errorf("cksmvfs: incorrect type: %T", f)
}
r, err := db.FileControl(schema, sqlite3.FCNTL_RESERVE_BYTES)
if err != nil {
return err
}
if r == 8 {
// Correct value, enabled.
return nil
}
if r == 0 {
// Default value, enable.
_, err = db.FileControl(schema, sqlite3.FCNTL_RESERVE_BYTES, 8)
if err != nil {
return err
}
r, err = db.FileControl(schema, sqlite3.FCNTL_RESERVE_BYTES)
if err != nil {
return err
}
}
if r != 8 {
// Invalid value.
return fmt.Errorf("cksmvfs: reserve bytes must be 8, is: %d", r)
}
// VACUUM the database.
if schema != "" {
err = db.Exec(`VACUUM ` + sqlite3.QuoteIdentifier(schema))
} else {
err = db.Exec(`VACUUM`)
}
if err != nil {
return err
}
// Checkpoint the WAL.
_, _, err = db.WALCheckpoint(schema, sqlite3.CHECKPOINT_RESTART)
return err
}

133
vfs/cksmvfs/api_test.go Normal file
View File

@@ -0,0 +1,133 @@
package cksmvfs_test
import (
_ "embed"
"log"
"path/filepath"
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
"github.com/ncruces/go-sqlite3/util/ioutil"
"github.com/ncruces/go-sqlite3/vfs"
"github.com/ncruces/go-sqlite3/vfs/cksmvfs"
"github.com/ncruces/go-sqlite3/vfs/memdb"
"github.com/ncruces/go-sqlite3/vfs/readervfs"
)
//go:embed testdata/cksm.db
var cksmDB string
func Test_fileformat(t *testing.T) {
readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(cksmDB)))
vfs.Register("rcksm", cksmvfs.Wrap(vfs.Find("reader")))
db, err := driver.Open("file:test.db?vfs=rcksm")
if err != nil {
t.Fatal(err)
}
defer db.Close()
var enabled bool
err = db.QueryRow(`PRAGMA checksum_verification`).Scan(&enabled)
if err != nil {
t.Fatal(err)
}
if !enabled {
t.Error("want true")
}
db.SetMaxIdleConns(0) // Clears the page cache.
_, err = db.Exec(`PRAGMA integrity_check`)
if err != nil {
t.Fatal(err)
}
}
//go:embed testdata/test.db
var testDB []byte
func Test_enable(t *testing.T) {
memdb.Create("nockpt.db", testDB)
vfs.Register("mcksm", cksmvfs.Wrap(vfs.Find("memdb")))
db, err := driver.Open("file:/nockpt.db?vfs=mcksm",
func(db *sqlite3.Conn) error {
return cksmvfs.EnableChecksums(db, "")
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
var enabled bool
err = db.QueryRow(`PRAGMA checksum_verification`).Scan(&enabled)
if err != nil {
t.Fatal(err)
}
if !enabled {
t.Error("want true")
}
db.SetMaxIdleConns(0) // Clears the page cache.
_, err = db.Exec(`PRAGMA integrity_check`)
if err != nil {
t.Fatal(err)
}
}
func Test_new(t *testing.T) {
if !vfs.SupportsFileLocking {
t.Skip("skipping without locks")
}
name := "file:" +
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db")) +
"?vfs=cksmvfs&_pragma=journal_mode(wal)"
db, err := driver.Open(name)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var enabled bool
err = db.QueryRow(`PRAGMA checksum_verification`).Scan(&enabled)
if err != nil {
t.Fatal(err)
}
if !enabled {
t.Error("want true")
}
var size int
err = db.QueryRow(`PRAGMA page_size=1024`).Scan(&size)
if err != nil {
t.Fatal(err)
}
if size != 4096 {
t.Errorf("got %d, want 4096", size)
}
_, err = db.Exec(`CREATE TABLE users (id INT, name VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}
_, err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
log.Fatal(err)
}
db.SetMaxIdleConns(0) // Clears the page cache.
_, err = db.Exec(`PRAGMA integrity_check`)
if err != nil {
t.Fatal(err)
}
}

234
vfs/cksmvfs/cksmvfs.go Normal file
View File

@@ -0,0 +1,234 @@
package cksmvfs
import (
"bytes"
_ "embed"
"encoding/binary"
"io"
"runtime"
"strconv"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/util/sql3util"
"github.com/ncruces/go-sqlite3/util/vfsutil"
"github.com/ncruces/go-sqlite3/vfs"
)
type cksmVFS struct {
vfs.VFS
}
func (c *cksmVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
// notest // OpenFilename is called instead
return nil, 0, sqlite3.CANTOPEN
}
func (c *cksmVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
// Prevent accidental wrapping.
if pc, _, _, ok := runtime.Caller(1); ok {
if fn := runtime.FuncForPC(pc); fn != nil {
if fn.Name() != "github.com/ncruces/go-sqlite3/vfs.vfsOpen" {
return nil, 0, sqlite3.CANTOPEN
}
}
}
file, flags, err = vfsutil.WrapOpenFilename(c.VFS, name, flags)
// Checksum only main databases and WALs.
if err != nil || flags&(vfs.OPEN_MAIN_DB|vfs.OPEN_WAL) == 0 {
return file, flags, err
}
cksm := cksmFile{File: file}
if flags&vfs.OPEN_WAL != 0 {
main, _ := name.DatabaseFile().(*cksmFile)
cksm.cksmFlags = main.cksmFlags
} else {
cksm.isDB = true
cksm.cksmFlags = new(cksmFlags)
}
const createDB = vfs.OPEN_CREATE | vfs.OPEN_READWRITE | vfs.OPEN_MAIN_DB
cksm.createDB = flags&createDB == createDB
return &cksm, flags, err
}
type cksmFile struct {
vfs.File
*cksmFlags
isDB bool
createDB bool
}
type cksmFlags struct {
computeCksm bool
verifyCksm bool
inCkpt bool
pageSize int
}
//go:embed empty.db
var empty string
func (c *cksmFile) ReadAt(p []byte, off int64) (n int, err error) {
n, err = c.File.ReadAt(p, off)
// SQLite is trying to read from the first page of an empty database file.
// Instead, read from an empty database that had checksums enabled,
// so checksums are enabled by default.
if c.createDB && n == 0 && err == io.EOF && off < 100 {
n = copy(p, empty[off:])
if n < len(p) {
clear(p[n:])
}
err = nil
}
// SQLite is reading the header of a database file.
if c.isDB && off == 0 && len(p) >= 100 &&
bytes.HasPrefix(p, []byte("SQLite format 3\000")) {
c.updateFlags(p)
}
// Verify checksums.
if c.verifyCksm && !c.inCkpt && len(p) == c.pageSize {
cksm1 := cksmCompute(p[:len(p)-8])
cksm2 := *(*[8]byte)(p[len(p)-8:])
if cksm1 != cksm2 {
return 0, sqlite3.IOERR_DATA
}
}
return n, err
}
func (c *cksmFile) WriteAt(p []byte, off int64) (n int, err error) {
// SQLite is writing the first page of a database file.
if c.isDB && off == 0 && len(p) >= 100 &&
bytes.HasPrefix(p, []byte("SQLite format 3\000")) {
c.updateFlags(p)
}
// Compute checksums.
if c.computeCksm && !c.inCkpt && len(p) == c.pageSize {
*(*[8]byte)(p[len(p)-8:]) = cksmCompute(p[:len(p)-8])
}
return c.File.WriteAt(p, off)
}
func (c *cksmFile) updateFlags(header []byte) {
c.pageSize = 256 * int(binary.LittleEndian.Uint16(header[16:18]))
if r := header[20] == 8; r != c.computeCksm {
c.computeCksm = r
c.verifyCksm = r
}
}
func (c *cksmFile) CheckpointStart() {
c.inCkpt = true
}
func (c *cksmFile) CheckpointDone() {
c.inCkpt = false
}
func (c *cksmFile) Pragma(name string, value string) (string, error) {
switch name {
case "checksum_verification":
b, ok := sql3util.ParseBool(value)
if ok {
c.verifyCksm = b && c.computeCksm
}
if !c.verifyCksm {
return "0", nil
}
return "1", nil
case "page_size":
if c.computeCksm {
// Do not allow page size changes on a checksum database.
return strconv.Itoa(c.pageSize), nil
}
}
return vfsutil.WrapPragma(c.File, name, value)
}
func cksmCompute(a []byte) (cksm [8]byte) {
var s1, s2 uint32
for len(a) >= 8 {
s1 += binary.LittleEndian.Uint32(a[0:4]) + s2
s2 += binary.LittleEndian.Uint32(a[4:8]) + s1
a = a[8:]
}
if len(a) != 0 {
panic(util.AssertErr())
}
binary.LittleEndian.PutUint32(cksm[0:4], s1)
binary.LittleEndian.PutUint32(cksm[4:8], s2)
return
}
func (c *cksmFile) Unwrap() vfs.File {
return c.File
}
func (c *cksmFile) SharedMemory() vfs.SharedMemory {
return vfsutil.WrapSharedMemory(c.File)
}
// Wrap optional methods.
func (c *cksmFile) LockState() vfs.LockLevel {
return vfsutil.WrapLockState(c.File) // notest
}
func (c *cksmFile) PersistentWAL() bool {
return vfsutil.WrapPersistentWAL(c.File) // notest
}
func (c *cksmFile) SetPersistentWAL(keepWAL bool) {
vfsutil.WrapSetPersistentWAL(c.File, keepWAL) // notest
}
func (c *cksmFile) PowersafeOverwrite() bool {
return vfsutil.WrapPowersafeOverwrite(c.File) // notest
}
func (c *cksmFile) SetPowersafeOverwrite(psow bool) {
vfsutil.WrapSetPowersafeOverwrite(c.File, psow) // notest
}
func (c *cksmFile) ChunkSize(size int) {
vfsutil.WrapChunkSize(c.File, size) // notest
}
func (c *cksmFile) SizeHint(size int64) error {
return vfsutil.WrapSizeHint(c.File, size) // notest
}
func (c *cksmFile) HasMoved() (bool, error) {
return vfsutil.WrapHasMoved(c.File) // notest
}
func (c *cksmFile) Overwrite() error {
return vfsutil.WrapOverwrite(c.File) // notest
}
func (c *cksmFile) CommitPhaseTwo() error {
return vfsutil.WrapCommitPhaseTwo(c.File) // notest
}
func (c *cksmFile) BeginAtomicWrite() error {
return vfsutil.WrapBeginAtomicWrite(c.File) // notest
}
func (c *cksmFile) CommitAtomicWrite() error {
return vfsutil.WrapCommitAtomicWrite(c.File) // notest
}
func (c *cksmFile) RollbackAtomicWrite() error {
return vfsutil.WrapRollbackAtomicWrite(c.File) // notest
}

BIN
vfs/cksmvfs/empty.db Normal file

Binary file not shown.

BIN
vfs/cksmvfs/testdata/cksm.db vendored Normal file

Binary file not shown.

BIN
vfs/cksmvfs/testdata/test.db vendored Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -21,7 +21,7 @@ var testDB string
func Test_fileformat(t *testing.T) { func Test_fileformat(t *testing.T) {
readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB))) readervfs.Create("test.db", ioutil.NewSizeReaderAt(strings.NewReader(testDB)))
xts.Register("rxts", vfs.Find("reader"), nil) vfs.Register("rxts", xts.Wrap(vfs.Find("reader"), nil))
db, err := driver.Open("file:test.db?vfs=rxts") db, err := driver.Open("file:test.db?vfs=rxts")
if err != nil { if err != nil {

View File

@@ -40,25 +40,26 @@ import (
) )
func init() { func init() {
Register("xts", vfs.Find(""), nil) vfs.Register("xts", Wrap(vfs.Find(""), nil))
} }
// Register registers an encrypting VFS, wrapping a base VFS, // Wrap wraps a base VFS to create an encrypting VFS,
// and possibly using a custom XTS cipher construction. // possibly using a custom XTS cipher construction.
//
// To use the default AES-XTS construction, set cipher to nil. // To use the default AES-XTS construction, set cipher to nil.
// //
// The default construction uses AES-128, AES-192, or AES-256 // The default construction uses AES-128, AES-192, or AES-256
// if the key/hexkey is 32, 48, or 64 bytes, respectively. // if the key/hexkey is 32, 48, or 64 bytes, respectively.
// If a textkey is provided, the default KDF is PBKDF2-HMAC-SHA512 // If a textkey is provided, the default KDF is PBKDF2-HMAC-SHA512
// with 10,000 iterations, always producing a 32 byte key. // with 10,000 iterations, always producing a 32 byte key.
func Register(name string, base vfs.VFS, cipher XTSCreator) { func Wrap(base vfs.VFS, cipher XTSCreator) vfs.VFS {
if cipher == nil { if cipher == nil {
cipher = aesCreator{} cipher = aesCreator{}
} }
vfs.Register(name, &xtsVFS{ return &xtsVFS{
VFS: base, VFS: base,
init: cipher, init: cipher,
}) }
} }
// XTSCreator creates an [xts.Cipher] // XTSCreator creates an [xts.Cipher]

View File

@@ -23,11 +23,7 @@ func (x *xtsVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag,
} }
func (x *xtsVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { func (x *xtsVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
if hf, ok := x.VFS.(vfs.VFSFilename); ok { file, flags, err = vfsutil.WrapOpenFilename(x.VFS, name, flags)
file, flags, err = hf.OpenFilename(name, flags)
} else {
file, flags, err = x.VFS.Open(name.String(), flags)
}
// Encrypt everything except super journals and memory files. // Encrypt everything except super journals and memory files.
if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 { if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
@@ -48,13 +44,14 @@ func (x *xtsVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.
} else if t, ok := params["textkey"]; ok && len(t[0]) > 0 { } else if t, ok := params["textkey"]; ok && len(t[0]) > 0 {
key = x.init.KDF(t[0]) key = x.init.KDF(t[0])
} else if flags&vfs.OPEN_MAIN_DB != 0 { } else if flags&vfs.OPEN_MAIN_DB != 0 {
// Main datatabases may have their key specified as a PRAGMA. // Main databases may have their key specified as a PRAGMA.
return &xtsFile{File: file, init: x.init}, flags, nil return &xtsFile{File: file, init: x.init}, flags, nil
} }
cipher = x.init.XTS(key) cipher = x.init.XTS(key)
} }
if cipher == nil { if cipher == nil {
file.Close()
return nil, flags, sqlite3.CANTOPEN return nil, flags, sqlite3.CANTOPEN
} }
return &xtsFile{File: file, cipher: cipher, init: x.init}, flags, nil return &xtsFile{File: file, cipher: cipher, init: x.init}, flags, nil