fix(v0): restores validateSig behavior

This commit is contained in:
Steve Moyer
2025-07-08 14:13:47 -04:00
committed by Michael Muré
parent 03770e0d38
commit f6b72f1907
5 changed files with 26 additions and 34 deletions

View File

@@ -79,11 +79,7 @@ func NewEdDSAVarsig(curve EdDSACurve, hashAlgorithm HashAlgorithm, payloadEncodi
hashAlg: hashAlgorithm,
}
err := v.validateSig(ed25519.SignatureSize)
if err != nil {
return EdDSAVarsig{}, err
}
return v, nil
return validateSig(v, ed25519.SignatureSize)
}
// Curve returns the Edwards curve used to generate the EdDSA signature.
@@ -142,10 +138,5 @@ func decodeEd25519(r BytesReader, vers Version, disc Discriminator) (Varsig, err
return nil, err
}
err = v.validateSig(ed25519.SignatureSize)
if err != nil {
return RSAVarsig{}, err
}
return v, nil
return validateSig(v, ed25519.SignatureSize)
}

14
rsa.go
View File

@@ -45,12 +45,7 @@ func NewRSAVarsig(hashAlgorithm HashAlgorithm, keyLength uint64, payloadEncoding
sigLen: keyLength,
}
err := v.validateSig(v.sigLen)
if err != nil {
return RSAVarsig{}, err
}
return v, nil
return validateSig(v, v.sigLen)
}
// Encode returns the encoded byte format of the RSAVarsig.
@@ -100,10 +95,5 @@ func decodeRSA(r BytesReader, vers Version, disc Discriminator) (Varsig, error)
return nil, err
}
err = vs.validateSig(vs.sigLen)
if err != nil {
return RSAVarsig{}, err
}
return vs, nil
return validateSig(vs, vs.sigLen)
}

View File

@@ -69,7 +69,7 @@ func TestUCANExample(t *testing.T) {
vs, err := varsig.Decode(example)
require.ErrorIs(t, err, varsig.ErrMissingSignature)
rsaVs, ok := vs.(*varsig.RSAVarsig)
rsaVs, ok := vs.(varsig.RSAVarsig)
require.True(t, ok)
assert.Equal(t, varsig.Version0, rsaVs.Version())

View File

@@ -20,6 +20,7 @@ package varsig
import (
"encoding/binary"
"errors"
"io"
)
@@ -100,30 +101,40 @@ func (v varsig) decodePayEncAndSig(r BytesReader) (PayloadEncoding, []byte, erro
}
var signature []byte
if v.Version() == Version0 {
switch v.Version() {
case Version0:
signature, err = io.ReadAll(r)
if err != nil {
return 0, nil, err
}
case Version1:
_, err := r.ReadByte()
if err != nil && !errors.Is(err, io.EOF) {
return 0, nil, err
}
if err == nil {
return 0, nil, ErrUnexpectedSignaturePresent
}
}
return payEnc, signature, nil
}
func (v varsig) validateSig(expectedLength uint64) error {
if v.Version() == Version0 && len(v.sig) == 0 {
return ErrMissingSignature
func validateSig[T Varsig](v T, expectedLength uint64) (T, error) {
if v.Version() == Version0 && len(v.Signature()) == 0 {
return v, ErrMissingSignature
}
if v.Version() == Version0 && uint64(len(v.sig)) != expectedLength {
return ErrUnexpectedSignatureSize
if v.Version() == Version0 && uint64(len(v.Signature())) != expectedLength {
return *new(T), ErrUnexpectedSignatureSize
}
if v.Version() == Version1 && len(v.sig) != 0 {
return ErrUnexpectedSignaturePresent
if v.Version() == Version1 && len(v.Signature()) != 0 {
return *new(T), ErrUnexpectedSignaturePresent
}
return nil
return v, nil
}
type BytesReader interface {

View File

@@ -133,7 +133,7 @@ func TestDecode(t *testing.T) {
vs, err := varsig.Decode(data)
require.ErrorIs(t, err, varsig.ErrUnexpectedSignatureSize)
assert.Nil(t, vs)
assert.Zero(t, vs)
})
t.Run("fails - unexpected signature present - v1", func(t *testing.T) {