Adiantum pragmas.

This commit is contained in:
Nuno Cruces
2024-04-27 12:19:46 +01:00
parent 3c21784aee
commit 811e6e63be
3 changed files with 69 additions and 36 deletions

View File

@@ -75,7 +75,7 @@ func Test_adiantum(t *testing.T) {
name := "file:" +
filepath.ToSlash(filepath.Join(t.TempDir(), "test.db")) +
"?vfs=adiantum" +
"&hexkey=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
"&_pragma=hexkey(e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855)"
testParallel(t, name, iter)
testIntegrity(t, name)
}

View File

@@ -21,40 +21,33 @@ func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag,
}
func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
var hbsh *hbsh.HBSH
// Encrypt everything except super journals.
if flags&vfs.OPEN_SUPER_JOURNAL == 0 {
if f, ok := name.DatabaseFile().(*hbshFile); ok {
hbsh = f.hbsh
} else {
var key []byte
if params := name.URIParameters(); name == nil {
key = h.hbsh.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok {
key = h.hbsh.KDF(t[0])
}
if hbsh = h.hbsh.HBSH(key); hbsh == nil {
// Can't open without a valid key.
return nil, flags, sqlite3.CANTOPEN
}
}
}
if h, ok := h.VFS.(vfs.VFSFilename); ok {
file, flags, err = h.OpenFilename(name, flags)
} else {
file, flags, err = h.Open(name.String(), flags)
}
if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 {
// Error, or no encryption (super journals, memory files).
// Encrypt everything except super journals and memory files.
if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 {
return file, flags, err
}
return &hbshFile{File: file, hbsh: hbsh}, flags, err
var hbsh *hbsh.HBSH
if f, ok := name.DatabaseFile().(*hbshFile); ok {
hbsh = f.hbsh
} else {
var key []byte
if params := name.URIParameters(); name == nil {
key = h.hbsh.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok {
key = h.hbsh.KDF(t[0])
}
hbsh = h.hbsh.HBSH(key)
}
return &hbshFile{File: file, hbsh: hbsh, reset: h.hbsh}, flags, err
}
const (
@@ -65,11 +58,43 @@ const (
type hbshFile struct {
vfs.File
hbsh *hbsh.HBSH
reset HBSHCreator
block [blockSize]byte
tweak [tweakSize]byte
}
func (h *hbshFile) Pragma(name string, value string) (string, error) {
var key []byte
switch name {
case "key":
key = []byte(value)
case "hexkey":
key, _ = hex.DecodeString(value)
case "textkey":
key = h.reset.KDF(value)
default:
if f, ok := h.File.(vfs.FilePragma); ok {
return f.Pragma(name, value)
}
return "", sqlite3.NOTFOUND
}
if h.hbsh = h.reset.HBSH(key); h.hbsh != nil {
return "ok", nil
}
return "", sqlite3.CANTOPEN
}
func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
if h.hbsh == nil {
// If it's trying to read the header, pretend the file is empty,
// so the key can be specified later.
if off == 0 && len(p) == 100 {
return 0, io.EOF
}
return 0, sqlite3.CANTOPEN
}
min := (off) &^ (blockSize - 1) // round down
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up
@@ -96,6 +121,10 @@ func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) {
}
func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) {
if h.hbsh == nil {
return 0, sqlite3.READONLY
}
min := (off) &^ (blockSize - 1) // round down
max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up

View File

@@ -347,12 +347,17 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
case _FCNTL_PRAGMA:
if file, ok := file.(FilePragma); ok {
name := util.ReadUint32(mod, pArg+1*ptrlen)
value := util.ReadUint32(mod, pArg+2*ptrlen)
out, err := file.Pragma(
util.ReadString(mod, name, _MAX_SQL_LENGTH),
util.ReadString(mod, value, _MAX_SQL_LENGTH))
if err != nil {
ptr := util.ReadUint32(mod, pArg+1*ptrlen)
name := util.ReadString(mod, ptr, _MAX_SQL_LENGTH)
var value string
if ptr := util.ReadUint32(mod, pArg+2*ptrlen); ptr != 0 {
value = util.ReadString(mod, ptr, _MAX_SQL_LENGTH)
}
out, err := file.Pragma(name, value)
ret := vfsErrorCode(err, _ERROR)
if ret == _ERROR {
out = err.Error()
}
if out != "" {
@@ -363,9 +368,8 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
}
util.WriteUint32(mod, pArg, uint32(stack[0]))
util.WriteString(mod, uint32(stack[0]), out)
return _ERROR
}
return vfsErrorCode(err, _ERROR)
return ret
}
}