From 21a78a9d2d3e64402e51af58a18aff1f1a7c7ccf Mon Sep 17 00:00:00 2001 From: Steve Moyer Date: Tue, 8 Jul 2025 11:27:18 -0400 Subject: [PATCH] fix(eddsa): use DecodeHashAlgorithm and create decodeEdDSACurve --- eddsa.go | 32 +++++++++++++++++++++++--------- error.go | 4 ++++ rsa.go | 2 +- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/eddsa.go b/eddsa.go index fc28a2d..250f823 100644 --- a/eddsa.go +++ b/eddsa.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/ed25519" "encoding/binary" + "fmt" "github.com/multiformats/go-multicodec" ) @@ -26,6 +27,20 @@ const ( CurveEd448 = EdDSACurve(multicodec.Ed448Pub) ) +func decodeEdDSACurve(r *bytes.Reader) (EdDSACurve, error) { + u, err := binary.ReadUvarint(r) + if err != nil { + return 0, err + } + + switch curve := EdDSACurve(u); curve { + case CurveEd25519, CurveEd448: + return curve, nil + default: + return 0, fmt.Errorf("%w: %x", ErrUnknownEdDSACurve, u) + } +} + var _ Varsig = (*EdDSAVarsig)(nil) // EdDSAVarsig is a varsig that encodes the parameters required to describe @@ -95,20 +110,19 @@ func (v EdDSAVarsig) Encode() []byte { } func decodeEd25519(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, error) { - curve := uint64(disc) + curve := EdDSACurve(disc) if vers != Version0 { - u, err := binary.ReadUvarint(r) + var err error + curve, err = decodeEdDSACurve(r) if err != nil { - return nil, err // TODO: wrap error? + return nil, err } - - curve = u } - hashAlg, err := binary.ReadUvarint(r) + hashAlg, err := DecodeHashAlgorithm(r) if err != nil { - return nil, err // TODO: wrap error? + return nil, err } v := &EdDSAVarsig{ @@ -116,8 +130,8 @@ func decodeEd25519(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, e vers: vers, disc: disc, }, - curve: EdDSACurve(curve), - hashAlg: HashAlgorithm(hashAlg), + curve: curve, + hashAlg: hashAlg, } return v.decodePayEncAndSig(r, v, ed25519.PrivateKeySize) diff --git a/error.go b/error.go index a5e6edb..908ed62 100644 --- a/error.go +++ b/error.go @@ -34,6 +34,10 @@ var ErrUnsupportedPayloadEncoding = errors.New("unsupported payload encoding") // parsing function for the decoded signing algorithm. var ErrUnknownDiscriminator = errors.New("unknown signing algorithm") +// ErrUnknownEdDSACurve is returned when the decoded uvarint isn't either +// CurveEd25519 or CurveEd448. +var ErrUnknownEdDSACurve = errors.New("unknown Edwards curve") + // ErrUnsupportedVersion is returned when an unsupported varsig version // field is present. var ErrUnsupportedVersion = errors.New("unsupported version") diff --git a/rsa.go b/rsa.go index 905c206..0aa486f 100644 --- a/rsa.go +++ b/rsa.go @@ -87,7 +87,7 @@ func decodeRSA(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, error vers: vers, disc: disc, }, - hashAlg: HashAlgorithm(hashAlg), + hashAlg: hashAlg, sigLen: sigLen, }