From 03770e0d38541db8a7db9f4dcd7d146434c88984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Mur=C3=A9?= Date: Tue, 8 Jul 2025 18:38:23 +0200 Subject: [PATCH] use value receiver, remove unneeded generic --- README.md | 6 +++--- common.go | 10 +++++----- common_test.go | 19 +++++++++++------- constant.go | 5 ++--- eddsa.go | 43 +++++++++++++++++++++++++-------------- eddsa_test.go | 2 +- error.go | 8 ++++---- registry.go | 14 ++++++------- registry_test.go | 14 ++++++------- rsa.go | 42 +++++++++++++++++++++++++------------- rsa_test.go | 2 +- varsig.go | 52 +++++++++++++++++++++++++----------------------- varsig_test.go | 10 +++------- 13 files changed, 128 insertions(+), 99 deletions(-) diff --git a/README.md b/README.md index 93b81f6..0a184e6 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,12 @@ `go-varsig` implements the upcoming v1.0.0 release of the [`varsig` specification](https://github.com/ChainAgnostic/varsig/pull/18) with limited (and soon to be deprecated) support for the `varsig` < v1.0 -specification. This is predominatly included to support the UCAN v1.0 +specification. This is predominantly included to support the UCAN v1.0 use-case. ## Usage -Include the `go-varsig` library by running the following command: +Include the `go-varsig` library by running the following command: ```bash go get github.com/ucan-wg/go-varsig@latest @@ -29,7 +29,7 @@ asdf install ### Checks -This repository contains an set of pre-commit hooks that are run prior to +This repository contains a set of pre-commit hooks that are run prior to each `git commit`. You can also run these checks manually using the following command: diff --git a/common.go b/common.go index 4eecb73..bcee856 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,7 @@ package varsig // by the [IANA JOSE specification]. // // [IANA JOSE specification]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms -func Ed25519(payloadEncoding PayloadEncoding, opts ...Option) (*EdDSAVarsig, error) { +func Ed25519(payloadEncoding PayloadEncoding, opts ...Option) (EdDSAVarsig, error) { return NewEdDSAVarsig(CurveEd25519, HashAlgorithmSHA512, payloadEncoding, opts...) } @@ -12,7 +12,7 @@ func Ed25519(payloadEncoding PayloadEncoding, opts ...Option) (*EdDSAVarsig, err // by the [IANA JOSE specification]. // // [IANA JOSE specification]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms -func Ed448(payloadEncoding PayloadEncoding, opts ...Option) (*EdDSAVarsig, error) { +func Ed448(payloadEncoding PayloadEncoding, opts ...Option) (EdDSAVarsig, error) { return NewEdDSAVarsig(CurveEd448, HashAlgorithmShake256, payloadEncoding, opts...) } @@ -20,7 +20,7 @@ func Ed448(payloadEncoding PayloadEncoding, opts ...Option) (*EdDSAVarsig, error // by the [IANA JOSE specification]. // // [IANA JOSE specification]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms -func RS256(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { +func RS256(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (RSAVarsig, error) { return NewRSAVarsig(HashAlgorithmSHA256, keyLength, payloadEncoding, opts...) } @@ -28,7 +28,7 @@ func RS256(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (* // by the [IANA JOSE specification]. // // [IANA JOSE specification]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms -func RS384(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { +func RS384(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (RSAVarsig, error) { return NewRSAVarsig(HashAlgorithmSHA384, keyLength, payloadEncoding, opts...) } @@ -36,6 +36,6 @@ func RS384(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (* // by the [IANA JOSE specification]. // // [IANA JOSE specification]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms -func RS512(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { +func RS512(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (RSAVarsig, error) { return NewRSAVarsig(HashAlgorithmSHA512, keyLength, payloadEncoding, opts...) } diff --git a/common_test.go b/common_test.go index 5b8f3ee..81cefdd 100644 --- a/common_test.go +++ b/common_test.go @@ -11,7 +11,8 @@ import ( func TestEd25519(t *testing.T) { t.Parallel() - in := mustVarsig[varsig.EdDSAVarsig](t)(varsig.Ed25519(varsig.PayloadEncodingDAGCBOR)) + in, err := varsig.Ed25519(varsig.PayloadEncodingDAGCBOR) + mustVarsig(t, in, err) out := roundTrip(t, in, "3401ed01ed011371") assertEdDSAEqual(t, in, out) } @@ -19,7 +20,8 @@ func TestEd25519(t *testing.T) { func TestEd448(t *testing.T) { t.Parallel() - in := mustVarsig[varsig.EdDSAVarsig](t)(varsig.Ed448(varsig.PayloadEncodingDAGCBOR)) + in, err := varsig.Ed448(varsig.PayloadEncodingDAGCBOR) + mustVarsig(t, in, err) out := roundTrip(t, in, "3401ed0183241971") assertEdDSAEqual(t, in, out) } @@ -27,7 +29,8 @@ func TestEd448(t *testing.T) { func TestRS256(t *testing.T) { t.Parallel() - in := mustVarsig[varsig.RSAVarsig](t)(varsig.RS256(0x100, varsig.PayloadEncodingDAGCBOR)) + in, err := varsig.RS256(0x100, varsig.PayloadEncodingDAGCBOR) + mustVarsig(t, in, err) out := roundTrip(t, in, "3401852412800271") assertRSAEqual(t, in, out) } @@ -35,7 +38,8 @@ func TestRS256(t *testing.T) { func TestRS384(t *testing.T) { t.Parallel() - in := mustVarsig[varsig.RSAVarsig](t)(varsig.RS384(0x100, varsig.PayloadEncodingDAGCBOR)) + in, err := varsig.RS384(0x100, varsig.PayloadEncodingDAGCBOR) + mustVarsig(t, in, err) out := roundTrip(t, in, "3401852420800271") assertRSAEqual(t, in, out) } @@ -43,19 +47,20 @@ func TestRS384(t *testing.T) { func TestRS512(t *testing.T) { t.Parallel() - in := mustVarsig[varsig.RSAVarsig](t)(varsig.RS512(0x100, varsig.PayloadEncodingDAGCBOR)) + in, err := varsig.RS512(0x100, varsig.PayloadEncodingDAGCBOR) + mustVarsig(t, in, err) out := roundTrip(t, in, "3401852413800271") assertRSAEqual(t, in, out) } -func assertEdDSAEqual(t *testing.T, in, out *varsig.EdDSAVarsig) { +func assertEdDSAEqual(t *testing.T, in, out varsig.EdDSAVarsig) { t.Helper() assert.Equal(t, in.Curve(), out.Curve()) assert.Equal(t, in.HashAlgorithm(), out.HashAlgorithm()) } -func assertRSAEqual(t *testing.T, in, out *varsig.RSAVarsig) { +func assertRSAEqual(t *testing.T, in, out varsig.RSAVarsig) { t.Helper() assert.Equal(t, in.HashAlgorithm(), out.HashAlgorithm()) diff --git a/constant.go b/constant.go index 82621f4..fea10f6 100644 --- a/constant.go +++ b/constant.go @@ -1,7 +1,6 @@ package varsig import ( - "bytes" "encoding/binary" "fmt" @@ -28,7 +27,7 @@ const ( // DecodeHashAlgorithm reads and validates the expected hash algorithm // (for varsig types include a variable hash algorithm.) -func DecodeHashAlgorithm(r *bytes.Reader) (HashAlgorithm, error) { +func DecodeHashAlgorithm(r BytesReader) (HashAlgorithm, error) { u, err := binary.ReadUvarint(r) if err != nil { return HashAlgorithmUnspecified, fmt.Errorf("%w: %w", ErrUnknownHashAlgorithm, err) @@ -67,7 +66,7 @@ const ( // DecodePayloadEncoding reads and validates the expected canonical payload // encoding of the data to be signed. -func DecodePayloadEncoding(r *bytes.Reader, vers Version) (PayloadEncoding, error) { +func DecodePayloadEncoding(r BytesReader, vers Version) (PayloadEncoding, error) { u, err := binary.ReadUvarint(r) if err != nil { return PayloadEncodingUnspecified, fmt.Errorf("%w: %w", ErrUnsupportedPayloadEncoding, err) diff --git a/eddsa.go b/eddsa.go index 250f823..8b24a52 100644 --- a/eddsa.go +++ b/eddsa.go @@ -1,7 +1,6 @@ package varsig import ( - "bytes" "crypto/ed25519" "encoding/binary" "fmt" @@ -27,7 +26,7 @@ const ( CurveEd448 = EdDSACurve(multicodec.Ed448Pub) ) -func decodeEdDSACurve(r *bytes.Reader) (EdDSACurve, error) { +func decodeEdDSACurve(r BytesReader) (EdDSACurve, error) { u, err := binary.ReadUvarint(r) if err != nil { return 0, err @@ -41,12 +40,12 @@ func decodeEdDSACurve(r *bytes.Reader) (EdDSACurve, error) { } } -var _ Varsig = (*EdDSAVarsig)(nil) +var _ Varsig = EdDSAVarsig{} // EdDSAVarsig is a varsig that encodes the parameters required to describe // an EdDSA signature. type EdDSAVarsig struct { - varsig[EdDSAVarsig] + varsig curve EdDSACurve hashAlg HashAlgorithm @@ -54,13 +53,13 @@ type EdDSAVarsig struct { // NewEdDSAVarsig creates and validates an EdDSA varsig with the provided // curve, hash algorithm and payload encoding. -func NewEdDSAVarsig(curve EdDSACurve, hashAlgorithm HashAlgorithm, payloadEncoding PayloadEncoding, opts ...Option) (*EdDSAVarsig, error) { +func NewEdDSAVarsig(curve EdDSACurve, hashAlgorithm HashAlgorithm, payloadEncoding PayloadEncoding, opts ...Option) (EdDSAVarsig, error) { options := newOptions(opts...) var ( vers = Version1 disc = DiscriminatorEdDSA - sig = []byte{} + sig []byte ) if options.ForceVersion0() { @@ -69,8 +68,8 @@ func NewEdDSAVarsig(curve EdDSACurve, hashAlgorithm HashAlgorithm, payloadEncodi sig = options.Signature() } - v := &EdDSAVarsig{ - varsig: varsig[EdDSAVarsig]{ + v := EdDSAVarsig{ + varsig: varsig{ vers: vers, disc: disc, payEnc: payloadEncoding, @@ -80,17 +79,21 @@ func NewEdDSAVarsig(curve EdDSACurve, hashAlgorithm HashAlgorithm, payloadEncodi hashAlg: hashAlgorithm, } - return v.validateSig(v, ed25519.PrivateKeySize) + err := v.validateSig(ed25519.SignatureSize) + if err != nil { + return EdDSAVarsig{}, err + } + return v, nil } // Curve returns the Edwards curve used to generate the EdDSA signature. -func (v *EdDSAVarsig) Curve() EdDSACurve { +func (v EdDSAVarsig) Curve() EdDSACurve { return v.curve } // HashAlgorithm returns the multicodec.Code describing the hash algorithm // used to hash the payload content before the signature is generated. -func (v *EdDSAVarsig) HashAlgorithm() HashAlgorithm { +func (v EdDSAVarsig) HashAlgorithm() HashAlgorithm { return v.hashAlg } @@ -109,7 +112,7 @@ func (v EdDSAVarsig) Encode() []byte { return buf } -func decodeEd25519(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, error) { +func decodeEd25519(r BytesReader, vers Version, disc Discriminator) (Varsig, error) { curve := EdDSACurve(disc) if vers != Version0 { var err error @@ -125,8 +128,8 @@ func decodeEd25519(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, e return nil, err } - v := &EdDSAVarsig{ - varsig: varsig[EdDSAVarsig]{ + v := EdDSAVarsig{ + varsig: varsig{ vers: vers, disc: disc, }, @@ -134,5 +137,15 @@ func decodeEd25519(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, e hashAlg: hashAlg, } - return v.decodePayEncAndSig(r, v, ed25519.PrivateKeySize) + v.payEnc, v.sig, err = v.decodePayEncAndSig(r) + if err != nil { + return nil, err + } + + err = v.validateSig(ed25519.SignatureSize) + if err != nil { + return RSAVarsig{}, err + } + + return v, nil } diff --git a/eddsa_test.go b/eddsa_test.go index 024dbc9..342f8d1 100644 --- a/eddsa_test.go +++ b/eddsa_test.go @@ -35,7 +35,7 @@ func TestDecodeEd25519(t *testing.T) { assert.Equal(t, varsig.PayloadEncodingDAGCBOR, v.PayloadEncoding()) assert.Len(t, v.Signature(), 64) - impl, ok := v.(*varsig.EdDSAVarsig) + impl, ok := v.(varsig.EdDSAVarsig) require.True(t, ok) assert.Equal(t, varsig.CurveEd25519, impl.Curve()) assert.Equal(t, varsig.HashAlgorithmSHA512, impl.HashAlgorithm()) diff --git a/error.go b/error.go index 908ed62..4603a57 100644 --- a/error.go +++ b/error.go @@ -13,7 +13,7 @@ var ErrMissingSignature = errors.New("missing signature expected in varsig v0") var ErrNotYetImplemented = errors.New("not yet implemented") // ErrUnexpectedSignaturePresent is returned when a signature is present -// in a varsig >= v1. +// in a varsig >= v1. var ErrUnexpectedSignaturePresent = errors.New("unexpected signature present in varsig >= v1") // ErrUnexpectedSignatureSize is returned when the length of the decoded @@ -21,7 +21,7 @@ var ErrUnexpectedSignaturePresent = errors.New("unexpected signature present in // signing algorithm or sent via a Varsig field. var ErrUnexpectedSignatureSize = errors.New("unexpected signature size in varsig v0") -// ErrUnknownHashAlgoritm is returned when an unexpected value is provided +// ErrUnknownHashAlgorithm is returned when an unexpected value is provided // while decoding the hashing algorithm. var ErrUnknownHashAlgorithm = errors.New("unknown hash algorithm") @@ -30,7 +30,7 @@ var ErrUnknownHashAlgorithm = errors.New("unknown hash algorithm") // for this field may vary based on the varsig version. var ErrUnsupportedPayloadEncoding = errors.New("unsupported payload encoding") -// ErrUnknowndiscorith is returned when the Registry doesn't have a +// ErrUnknownDiscriminator is returned when the Registry doesn't have a // parsing function for the decoded signing algorithm. var ErrUnknownDiscriminator = errors.New("unknown signing algorithm") @@ -43,5 +43,5 @@ var ErrUnknownEdDSACurve = errors.New("unknown Edwards curve") var ErrUnsupportedVersion = errors.New("unsupported version") // ErrBadPrefix is returned when the prefix field contains a value other -// than 0x34 (encoded as a uvarint). +// than 0x34 (encoded as an uvarint). var ErrBadPrefix = errors.New("varsig prefix not found") diff --git a/registry.go b/registry.go index ca3bc1f..e183bc0 100644 --- a/registry.go +++ b/registry.go @@ -6,11 +6,11 @@ import ( "fmt" ) -// Version represents which version of the vasig specification was used +// Version represents which version of the varsig specification was used // to produce Varsig value. type Version uint64 -// Constancts for the existing varsig specifications +// Constants for the existing varsig specifications const ( Version0 Version = 0 Version1 Version = 1 @@ -18,9 +18,9 @@ const ( // DecodeFunc is a function that parses the varsig representing a specific // signing algorithm. -type DecodeFunc func(*bytes.Reader, Version, Discriminator) (Varsig, error) +type DecodeFunc func(BytesReader, Version, Discriminator) (Varsig, error) -// Registry contains a mapping between known signing algorithms, and +// Registry contains a mapping between known signing algorithms and // functions that can parse varsigs for that signing algorithm. type Registry map[Discriminator]DecodeFunc @@ -56,7 +56,7 @@ func (rs Registry) Decode(data []byte) (Varsig, error) { // DecodeStream converts data read from the provided io.Reader into one // of the registered Varsig types. -func (rs Registry) DecodeStream(r *bytes.Reader) (Varsig, error) { +func (rs Registry) DecodeStream(r BytesReader) (Varsig, error) { pre, err := binary.ReadUvarint(r) if err != nil { return nil, fmt.Errorf("%w: %w", ErrBadPrefix, err) @@ -79,7 +79,7 @@ func (rs Registry) DecodeStream(r *bytes.Reader) (Varsig, error) { return decodeFunc(r, vers, disc) } -func (rs Registry) decodeVersAnddisc(r *bytes.Reader) (Version, Discriminator, error) { +func (rs Registry) decodeVersAnddisc(r BytesReader) (Version, Discriminator, error) { vers, err := binary.ReadUvarint(r) if err != nil { return Version(vers), 0, err @@ -98,6 +98,6 @@ func (rs Registry) decodeVersAnddisc(r *bytes.Reader) (Version, Discriminator, e return Version(vers), Discriminator(disc), err } -func notYetImplementedVarsigDecoder(_ *bytes.Reader, vers Version, disc Discriminator) (Varsig, error) { +func notYetImplementedVarsigDecoder(_ BytesReader, vers Version, disc Discriminator) (Varsig, error) { return nil, fmt.Errorf("%w: Version: %d, Discriminator: %x", ErrNotYetImplemented, vers, disc) } diff --git a/registry_test.go b/registry_test.go index da7bc99..75c6df2 100644 --- a/registry_test.go +++ b/registry_test.go @@ -61,7 +61,7 @@ func testRegistry(t *testing.T) varsig.Registry { func testDecodeFunc(t *testing.T) varsig.DecodeFunc { t.Helper() - return func(r *bytes.Reader, vers varsig.Version, disc varsig.Discriminator) (varsig.Varsig, error) { + return func(r varsig.BytesReader, vers varsig.Version, disc varsig.Discriminator) (varsig.Varsig, error) { v := &testVarsig{ vers: vers, disc: disc, @@ -71,7 +71,7 @@ func testDecodeFunc(t *testing.T) varsig.DecodeFunc { } } -var _ varsig.Varsig = (*testVarsig)(nil) +var _ varsig.Varsig = testVarsig{} type testVarsig struct { vers varsig.Version @@ -80,22 +80,22 @@ type testVarsig struct { sig []byte } -func (v *testVarsig) Version() varsig.Version { +func (v testVarsig) Version() varsig.Version { return v.vers } -func (v *testVarsig) Discriminator() varsig.Discriminator { +func (v testVarsig) Discriminator() varsig.Discriminator { return v.disc } -func (v *testVarsig) PayloadEncoding() varsig.PayloadEncoding { +func (v testVarsig) PayloadEncoding() varsig.PayloadEncoding { return v.payEnc } -func (v *testVarsig) Signature() []byte { +func (v testVarsig) Signature() []byte { return v.sig } -func (v *testVarsig) Encode() []byte { +func (v testVarsig) Encode() []byte { return nil } diff --git a/rsa.go b/rsa.go index 0aa486f..411bf8c 100644 --- a/rsa.go +++ b/rsa.go @@ -1,7 +1,6 @@ package varsig import ( - "bytes" "encoding/binary" "github.com/multiformats/go-multicodec" @@ -10,24 +9,24 @@ import ( // DiscriminatorRSA is the multicodec.Code specifying an RSA signature. const DiscriminatorRSA = Discriminator(multicodec.RsaPub) -var _ Varsig = (*RSAVarsig)(nil) +var _ Varsig = RSAVarsig{} // RSAVarsig is a varsig that encodes the parameters required to describe // an RSA signature. type RSAVarsig struct { - varsig[RSAVarsig] + varsig hashAlg HashAlgorithm sigLen uint64 } // NewRSAVarsig creates and validates an RSA varsig with the provided // hash algorithm, key length and payload encoding. -func NewRSAVarsig(hashAlgorithm HashAlgorithm, keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { +func NewRSAVarsig(hashAlgorithm HashAlgorithm, keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (RSAVarsig, error) { options := newOptions(opts...) var ( vers = Version1 - sig = []byte{} + sig []byte ) if options.ForceVersion0() { @@ -35,8 +34,8 @@ func NewRSAVarsig(hashAlgorithm HashAlgorithm, keyLength uint64, payloadEncoding sig = options.Signature() } - v := &RSAVarsig{ - varsig: varsig[RSAVarsig]{ + v := RSAVarsig{ + varsig: varsig{ vers: vers, disc: DiscriminatorRSA, payEnc: payloadEncoding, @@ -46,7 +45,12 @@ func NewRSAVarsig(hashAlgorithm HashAlgorithm, keyLength uint64, payloadEncoding sigLen: keyLength, } - return v.validateSig(v, v.sigLen) + err := v.validateSig(v.sigLen) + if err != nil { + return RSAVarsig{}, err + } + + return v, nil } // Encode returns the encoded byte format of the RSAVarsig. @@ -61,17 +65,17 @@ func (v RSAVarsig) Encode() []byte { } // HashAlgorithm returns the hash algorithm used to has the payload content. -func (v *RSAVarsig) HashAlgorithm() HashAlgorithm { +func (v RSAVarsig) HashAlgorithm() HashAlgorithm { return v.hashAlg } // KeyLength returns the length of the RSA key used to sign the payload // content. -func (v *RSAVarsig) KeyLength() uint64 { +func (v RSAVarsig) KeyLength() uint64 { return v.sigLen } -func decodeRSA(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, error) { +func decodeRSA(r BytesReader, vers Version, disc Discriminator) (Varsig, error) { hashAlg, err := DecodeHashAlgorithm(r) if err != nil { return nil, err @@ -82,8 +86,8 @@ func decodeRSA(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, error return nil, err } - vs := &RSAVarsig{ - varsig: varsig[RSAVarsig]{ + vs := RSAVarsig{ + varsig: varsig{ vers: vers, disc: disc, }, @@ -91,5 +95,15 @@ func decodeRSA(r *bytes.Reader, vers Version, disc Discriminator) (Varsig, error sigLen: sigLen, } - return vs.decodePayEncAndSig(r, vs, sigLen) + vs.payEnc, vs.sig, err = vs.decodePayEncAndSig(r) + if err != nil { + return nil, err + } + + err = vs.validateSig(vs.sigLen) + if err != nil { + return RSAVarsig{}, err + } + + return vs, nil } diff --git a/rsa_test.go b/rsa_test.go index c8aad7d..ea4f2ee 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -26,7 +26,7 @@ func TestRSAVarsig(t *testing.T) { vs, err := varsig.Decode(example) require.NoError(t, err) - rsaVs, ok := vs.(*varsig.RSAVarsig) + rsaVs, ok := vs.(varsig.RSAVarsig) require.True(t, ok) assert.Equal(t, varsig.Version1, rsaVs.Version()) diff --git a/varsig.go b/varsig.go index 883915e..d08dea4 100644 --- a/varsig.go +++ b/varsig.go @@ -19,7 +19,6 @@ package varsig import ( - "bytes" "encoding/binary" "io" ) @@ -45,11 +44,11 @@ func Decode(data []byte) (Varsig, error) { // DecodeStream converts data read from the provided io.Reader into one // of the Varsig types provided by the DefaultRegistry. -func DecodeStream(r *bytes.Reader) (Varsig, error) { +func DecodeStream(r BytesReader) (Varsig, error) { return DefaultRegistry().DecodeStream(r) } -type varsig[T Varsig] struct { +type varsig struct { vers Version disc Discriminator payEnc PayloadEncoding @@ -57,30 +56,30 @@ type varsig[T Varsig] struct { } // Version returns the varsig's version field. -func (v varsig[_]) Version() Version { +func (v varsig) Version() Version { return v.vers } -// Discriminator returns the algorithm used to produce corresponding +// Discriminator returns the algorithm used to produce the corresponding // signature. -func (v varsig[_]) Discriminator() Discriminator { +func (v varsig) Discriminator() Discriminator { return v.disc } // PayloadEncoding returns the codec that was used to encode the signed // data. -func (v varsig[_]) PayloadEncoding() PayloadEncoding { +func (v varsig) PayloadEncoding() PayloadEncoding { return v.payEnc } // Signature returns the cryptographic signature of the signed data. This // value is never present in a varsig >= v1 and must either be a valid // signature with the correct length or empty in varsig < v1. -func (v varsig[_]) Signature() []byte { +func (v varsig) Signature() []byte { return v.sig } -func (v *varsig[_]) encode() []byte { +func (v varsig) encode() []byte { var buf []byte buf = binary.AppendUvarint(buf, Prefix) @@ -94,37 +93,40 @@ func (v *varsig[_]) encode() []byte { return buf } -func (v *varsig[T]) decodePayEncAndSig(r *bytes.Reader, varsig *T, expectedLength uint64) (*T, error) { +func (v varsig) decodePayEncAndSig(r BytesReader) (PayloadEncoding, []byte, error) { payEnc, err := DecodePayloadEncoding(r, v.Version()) if err != nil { - return nil, err + return 0, nil, err } - v.payEnc = payEnc - - signature, err := io.ReadAll(r) - if err != nil { - return nil, err + var signature []byte + if v.Version() == Version0 { + signature, err = io.ReadAll(r) + if err != nil { + return 0, nil, err + } } - v.sig = signature - - return v.validateSig(varsig, expectedLength) + return payEnc, signature, nil } -func (v *varsig[T]) validateSig(varsig *T, expectedLength uint64) (*T, error) { +func (v varsig) validateSig(expectedLength uint64) error { if v.Version() == Version0 && len(v.sig) == 0 { - return varsig, ErrMissingSignature + return ErrMissingSignature } if v.Version() == Version0 && uint64(len(v.sig)) != expectedLength { - return nil, ErrUnexpectedSignatureSize + return ErrUnexpectedSignatureSize } if v.Version() == Version1 && len(v.sig) != 0 { - return nil, ErrUnexpectedSignaturePresent + return ErrUnexpectedSignaturePresent } - return varsig, nil - + return nil +} + +type BytesReader interface { + io.ByteReader + io.Reader } diff --git a/varsig_test.go b/varsig_test.go index a947bdc..d3140b9 100644 --- a/varsig_test.go +++ b/varsig_test.go @@ -158,15 +158,11 @@ func TestDecode(t *testing.T) { }) } -func mustVarsig[T varsig.Varsig](t *testing.T) func(*T, error) *T { +func mustVarsig[T varsig.Varsig](t *testing.T, v T, err error) { t.Helper() - return func(v *T, err error) *T { - if err != nil && ((*v).Version() != varsig.Version0 || !errors.Is(err, varsig.ErrMissingSignature)) { - t.Error(err) - } - - return v + if err != nil && (v.Version() != varsig.Version0 || !errors.Is(err, varsig.ErrMissingSignature)) { + t.Error(err) } }