diff --git a/tests/parallel/parallel_test.go b/tests/parallel/parallel_test.go index e42c13a..40675c5 100644 --- a/tests/parallel/parallel_test.go +++ b/tests/parallel/parallel_test.go @@ -75,7 +75,7 @@ func Test_adiantum(t *testing.T) { name := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "test.db")) + "?vfs=adiantum" + - "&hexkey=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + "&_pragma=hexkey(e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855)" testParallel(t, name, iter) testIntegrity(t, name) } diff --git a/vfs/adiantum/hbsh.go b/vfs/adiantum/hbsh.go index 082999a..1072811 100644 --- a/vfs/adiantum/hbsh.go +++ b/vfs/adiantum/hbsh.go @@ -21,40 +21,33 @@ func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, } func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) { - var hbsh *hbsh.HBSH - - // Encrypt everything except super journals. - if flags&vfs.OPEN_SUPER_JOURNAL == 0 { - if f, ok := name.DatabaseFile().(*hbshFile); ok { - hbsh = f.hbsh - } else { - var key []byte - if params := name.URIParameters(); name == nil { - key = h.hbsh.KDF("") // Temporary files get a random key. - } else if t, ok := params["key"]; ok { - key = []byte(t[0]) - } else if t, ok := params["hexkey"]; ok { - key, _ = hex.DecodeString(t[0]) - } else if t, ok := params["textkey"]; ok { - key = h.hbsh.KDF(t[0]) - } - if hbsh = h.hbsh.HBSH(key); hbsh == nil { - // Can't open without a valid key. - return nil, flags, sqlite3.CANTOPEN - } - } - } - if h, ok := h.VFS.(vfs.VFSFilename); ok { file, flags, err = h.OpenFilename(name, flags) } else { file, flags, err = h.Open(name.String(), flags) } - if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 { - // Error, or no encryption (super journals, memory files). + // Encrypt everything except super journals and memory files. + if err != nil || flags&(vfs.OPEN_SUPER_JOURNAL|vfs.OPEN_MEMORY) != 0 { return file, flags, err } - return &hbshFile{File: file, hbsh: hbsh}, flags, err + + var hbsh *hbsh.HBSH + if f, ok := name.DatabaseFile().(*hbshFile); ok { + hbsh = f.hbsh + } else { + var key []byte + if params := name.URIParameters(); name == nil { + key = h.hbsh.KDF("") // Temporary files get a random key. + } else if t, ok := params["key"]; ok { + key = []byte(t[0]) + } else if t, ok := params["hexkey"]; ok { + key, _ = hex.DecodeString(t[0]) + } else if t, ok := params["textkey"]; ok { + key = h.hbsh.KDF(t[0]) + } + hbsh = h.hbsh.HBSH(key) + } + return &hbshFile{File: file, hbsh: hbsh, reset: h.hbsh}, flags, err } const ( @@ -65,11 +58,43 @@ const ( type hbshFile struct { vfs.File hbsh *hbsh.HBSH + reset HBSHCreator block [blockSize]byte tweak [tweakSize]byte } +func (h *hbshFile) Pragma(name string, value string) (string, error) { + var key []byte + switch name { + case "key": + key = []byte(value) + case "hexkey": + key, _ = hex.DecodeString(value) + case "textkey": + key = h.reset.KDF(value) + default: + if f, ok := h.File.(vfs.FilePragma); ok { + return f.Pragma(name, value) + } + return "", sqlite3.NOTFOUND + } + + if h.hbsh = h.reset.HBSH(key); h.hbsh != nil { + return "ok", nil + } + return "", sqlite3.CANTOPEN +} + func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) { + if h.hbsh == nil { + // If it's trying to read the header, pretend the file is empty, + // so the key can be specified later. + if off == 0 && len(p) == 100 { + return 0, io.EOF + } + return 0, sqlite3.CANTOPEN + } + min := (off) &^ (blockSize - 1) // round down max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up @@ -96,6 +121,10 @@ func (h *hbshFile) ReadAt(p []byte, off int64) (n int, err error) { } func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) { + if h.hbsh == nil { + return 0, sqlite3.READONLY + } + min := (off) &^ (blockSize - 1) // round down max := (off + int64(len(p)) + blockSize - 1) &^ (blockSize - 1) // round up diff --git a/vfs/vfs.go b/vfs/vfs.go index 503a35c..805716e 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -347,12 +347,17 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl case _FCNTL_PRAGMA: if file, ok := file.(FilePragma); ok { - name := util.ReadUint32(mod, pArg+1*ptrlen) - value := util.ReadUint32(mod, pArg+2*ptrlen) - out, err := file.Pragma( - util.ReadString(mod, name, _MAX_SQL_LENGTH), - util.ReadString(mod, value, _MAX_SQL_LENGTH)) - if err != nil { + ptr := util.ReadUint32(mod, pArg+1*ptrlen) + name := util.ReadString(mod, ptr, _MAX_SQL_LENGTH) + var value string + if ptr := util.ReadUint32(mod, pArg+2*ptrlen); ptr != 0 { + value = util.ReadString(mod, ptr, _MAX_SQL_LENGTH) + } + + out, err := file.Pragma(name, value) + + ret := vfsErrorCode(err, _ERROR) + if ret == _ERROR { out = err.Error() } if out != "" { @@ -363,9 +368,8 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl } util.WriteUint32(mod, pArg, uint32(stack[0])) util.WriteString(mod, uint32(stack[0]), out) - return _ERROR } - return vfsErrorCode(err, _ERROR) + return ret } }