diff --git a/vfs/adiantum/adiantum.go b/vfs/adiantum/adiantum.go index 13d1bdd..39e45bb 100644 --- a/vfs/adiantum/adiantum.go +++ b/vfs/adiantum/adiantum.go @@ -13,6 +13,9 @@ const pepper = "github.com/ncruces/go-sqlite3/vfs/adiantum" type adiantumCreator struct{} func (adiantumCreator) HBSH(key []byte) *hbsh.HBSH { + if len(key) != 32 { + return nil + } return adiantum.New(key) } diff --git a/vfs/adiantum/hbsh.go b/vfs/adiantum/hbsh.go index 8d40d33..6b40e41 100644 --- a/vfs/adiantum/hbsh.go +++ b/vfs/adiantum/hbsh.go @@ -22,29 +22,41 @@ func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, } func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values) (file vfs.File, _ vfs.OpenFlag, err error) { + encrypt := flags&(0| + vfs.OPEN_MAIN_DB| + vfs.OPEN_MAIN_JOURNAL| + vfs.OPEN_SUBJOURNAL| + vfs.OPEN_WAL) != 0 + + var hbsh *hbsh.HBSH + if encrypt { + var key []byte + if t, ok := params["key"]; ok { + key = []byte(t[0]) + } else if t, ok := params["hexkey"]; ok { + key, _ = hex.DecodeString(t[0]) + } else if t, ok := params["textkey"]; ok { + key = h.hbsh.KDF(t[0]) + } + + if hbsh = h.hbsh.HBSH(key); hbsh == nil { + return nil, flags, sqlite3.NOTADB + } + } + if h, ok := h.VFS.(vfs.VFSParams); ok { + delete(params, "vfs") + delete(params, "key") + delete(params, "hexkey") + delete(params, "textkey") file, flags, err = h.OpenParams(name, flags, params) } else { file, flags, err = h.Open(name, flags) } - if err != nil || flags&(0| - vfs.OPEN_MAIN_DB| - vfs.OPEN_MAIN_JOURNAL| - vfs.OPEN_SUBJOURNAL| - vfs.OPEN_WAL) == 0 { + if err != nil || hbsh == nil { return file, flags, err } - - var key []byte - if t, ok := params["key"]; ok { - key = []byte(t[0]) - } else if t, ok := params["hexkey"]; ok { - key, err = hex.DecodeString(t[0]) - } else if t, ok := params["textkey"]; ok { - key = h.hbsh.KDF(t[0]) - } - - return &hbshFile{File: file, hbsh: h.hbsh.HBSH(key)}, flags, err + return &hbshFile{File: file, hbsh: hbsh}, flags, err } const ( @@ -93,20 +105,22 @@ func (h *hbshFile) WriteAt(p []byte, off int64) (n int, err error) { binary.LittleEndian.PutUint64(h.tweak[:], uint64(min)) data := h.block[:] - if min < off || len(p[n:]) < blockSize { + if off > min || len(p[n:]) < blockSize { // Read full block. m, err := h.File.ReadAt(h.block[:], min) - switch { - case m == 0 && err == io.EOF: - clear(data) - case m != blockSize: - return n, err - default: - // Partial update. - data = h.hbsh.Decrypt(h.block[:], h.tweak[:]) - if off > min { - data = data[off-min:] + if m != blockSize { + if err != io.EOF { + return n, err } + // Writing past the EOF. + // A partially written block is corrupt, + // and also considered to be past the EOF. + clear(data) + } + + data = h.hbsh.Decrypt(h.block[:], h.tweak[:]) + if off > min { + data = data[off-min:] } }