diff --git a/common.go b/common.go new file mode 100644 index 0000000..f497347 --- /dev/null +++ b/common.go @@ -0,0 +1,25 @@ +package varsig + +// RS256 produces a varsig that describes the associated algorithm defined +// by the [IANA JOSE specification]. +// +// [IANA JOSE specidication]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms +func RS256(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { + return NewRSAVarsig(HashAlgorithmSHA256, keyLength, payloadEncoding, opts...) +} + +// RS384 produces a varsig that describes the associated algorithm defined +// by the [IANA JOSE specification]. +// +// [IANA JOSE specidication]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms +func RS384(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { + return NewRSAVarsig(HashAlgorithmSHA384, keyLength, payloadEncoding, opts...) +} + +// RS512 produces a varsig that describes the associated algorithm defined +// by the [IANA JOSE specification]. +// +// [IANA JOSE specidication]: https://www.iana.org/assignments/jose/jose.xhtml#web-signature-encryption-algorithms +func RS512(keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { + return NewRSAVarsig(HashAlgorithmSHA512, keyLength, payloadEncoding, opts...) +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..0d04a2a --- /dev/null +++ b/common_test.go @@ -0,0 +1,40 @@ +package varsig_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/selesy/go-varsig" +) + +func TestRS256(t *testing.T) { + t.Parallel() + + in := mustVarsig[varsig.RSAVarsig](t)(varsig.RS256(0x100, varsig.PayloadEncodingDAGCBOR)) + out := roundTrip(t, in, "NAGFJBKAAnE") + assertRSAEqual(t, in, out) +} + +func TestRS384(t *testing.T) { + t.Parallel() + + in := mustVarsig[varsig.RSAVarsig](t)(varsig.RS384(0x100, varsig.PayloadEncodingDAGCBOR)) + out := roundTrip(t, in, "NAGFJCCAAnE") + assertRSAEqual(t, in, out) +} + +func TestRS512(t *testing.T) { + t.Parallel() + + in := mustVarsig[varsig.RSAVarsig](t)(varsig.RS512(0x100, varsig.PayloadEncodingDAGCBOR)) + out := roundTrip(t, in, "NAGFJBOAAnE") + assertRSAEqual(t, in, out) +} + +func assertRSAEqual(t *testing.T, in, out *varsig.RSAVarsig) { + t.Helper() + + assert.Equal(t, in.HashAlgorithm(), out.HashAlgorithm()) + assert.Equal(t, in.KeyLength(), out.KeyLength()) +} diff --git a/rsa.go b/rsa.go new file mode 100644 index 0000000..917a5fa --- /dev/null +++ b/rsa.go @@ -0,0 +1,94 @@ +package varsig + +import ( + "bytes" + "encoding/binary" + + "github.com/multiformats/go-multicodec" +) + +const SignAlgorithmRSA = SignAlgorithm(multicodec.RsaPub) + +var _ Varsig = (*RSAVarsig)(nil) + +// RSAVarsig is a varsig that encodes the parameters required to describe +// and RSA signature. +type RSAVarsig struct { + varsig[RSAVarsig] + hashAlg HashAlgorithm + sigLen uint64 +} + +// NewRSAVarsig creates and validates an RSA varsig with the provided +// parameters. +func NewRSAVarsig(hashAlgorithm HashAlgorithm, keyLength uint64, payloadEncoding PayloadEncoding, opts ...Option) (*RSAVarsig, error) { + options := newOptions(opts...) + + var ( + vers = Version1 + sig = []byte{} + ) + + if options.ForceVersion0() { + vers = Version0 + sig = options.Signature() + } + + v := &RSAVarsig{ + varsig: varsig[RSAVarsig]{ + vers: vers, + signAlg: SignAlgorithmRSA, + payEnc: payloadEncoding, + sig: sig, + }, + hashAlg: hashAlgorithm, + sigLen: keyLength, + } + + return v.validateSig(v, v.sigLen) +} + +// Encode returns the encoded byte formation of the RSAVarsig. +func (v RSAVarsig) Encode() []byte { + buf := v.encode() + buf = binary.AppendUvarint(buf, uint64(v.hashAlg)) + buf = binary.AppendUvarint(buf, v.sigLen) + buf = binary.AppendUvarint(buf, uint64(v.payEnc)) + buf = append(buf, v.Signature()...) + + return buf +} + +// HashAlgorithm returns the hash algorithm used to has the payload content. +func (v *RSAVarsig) HashAlgorithm() HashAlgorithm { + return v.hashAlg +} + +// KeyLength returns the length of the RSA key used to sign the payload +// content. +func (v *RSAVarsig) KeyLength() uint64 { + return v.sigLen +} + +func decodeRSA(r *bytes.Reader, vers Version, signAlg SignAlgorithm) (Varsig, error) { + hashAlg, err := DecodeHashAlgorithm(r) + if err != nil { + return nil, err + } + + sigLen, err := binary.ReadUvarint(r) + if err != nil { + return nil, err + } + + vs := &RSAVarsig{ + varsig: varsig[RSAVarsig]{ + vers: vers, + signAlg: signAlg, + }, + hashAlg: HashAlgorithm(hashAlg), + sigLen: sigLen, + } + + return vs.decodePayEncAndSig(r, vs, sigLen) +} diff --git a/rsa_test.go b/rsa_test.go new file mode 100644 index 0000000..7fc46cd --- /dev/null +++ b/rsa_test.go @@ -0,0 +1,96 @@ +package varsig_test + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/selesy/go-varsig" +) + +func TestRSAVarsig(t *testing.T) { + t.Parallel() + + const keyLen = 0x100 + + // This test uses the same RSA configuration as below but for varsig + // >= v1 + example, err := base64.RawStdEncoding.DecodeString("NAGFJBKAAnE") + require.NoError(t, err) + + t.Run("Decode", func(t *testing.T) { + t.Parallel() + + vs, err := varsig.Decode(example) + require.NoError(t, err) + + rsaVs, ok := vs.(*varsig.RSAVarsig) + require.True(t, ok) + + assert.Equal(t, varsig.Version1, rsaVs.Version()) + assert.Equal(t, varsig.SignAlgorithmRSA, rsaVs.SignatureAlgorithm()) + assert.Equal(t, varsig.HashAlgorithmSHA256, rsaVs.HashAlgorithm()) + assert.Equal(t, varsig.PayloadEncodingDAGCBOR, rsaVs.PayloadEncoding()) + assert.Equal(t, uint64(keyLen), rsaVs.KeyLength()) + assert.Len(t, rsaVs.Signature(), 0) + }) + + t.Run("Encode", func(t *testing.T) { + t.Parallel() + + rsaVarsig, err := varsig.NewRSAVarsig( + varsig.HashAlgorithmSHA256, + keyLen, + varsig.PayloadEncodingDAGCBOR, + ) + require.NoError(t, err) + + assert.Equal(t, example, rsaVarsig.Encode()) + t.Log(base64.RawStdEncoding.EncodeToString(rsaVarsig.Encode())) + }) +} + +func TestUCANExample(t *testing.T) { + t.Parallel() + + const keyLen = 0x100 + + // This test is the value shown in the UCAN v1.0.0 example, which is + // an RSA varsig < v1 encoded as RS256 with a key length of 0x100 + // bytes and DAG-CBOR payload encoding. + example, err := base64.RawStdEncoding.DecodeString("NIUkEoACcQ") + require.NoError(t, err) + + t.Run("Decode", func(t *testing.T) { + t.Parallel() + + vs, err := varsig.Decode(example) + require.ErrorIs(t, err, varsig.ErrMissingSignature) + + rsaVs, ok := vs.(*varsig.RSAVarsig) + require.True(t, ok) + + assert.Equal(t, varsig.Version0, rsaVs.Version()) + assert.Equal(t, varsig.SignAlgorithmRSA, rsaVs.SignatureAlgorithm()) + assert.Equal(t, varsig.HashAlgorithmSHA256, rsaVs.HashAlgorithm()) + assert.Equal(t, varsig.PayloadEncodingDAGCBOR, rsaVs.PayloadEncoding()) + assert.Equal(t, uint64(keyLen), rsaVs.KeyLength()) + assert.Len(t, rsaVs.Signature(), 0) + }) + + t.Run("Encode", func(t *testing.T) { + t.Parallel() + + rsaVarsig, err := varsig.NewRSAVarsig( + varsig.HashAlgorithmSHA256, + keyLen, + varsig.PayloadEncodingDAGCBOR, + varsig.WithForceVersion0([]byte{}), + ) + require.ErrorIs(t, err, varsig.ErrMissingSignature) + + assert.Equal(t, example, rsaVarsig.Encode()) + }) +} diff --git a/varsig_test.go b/varsig_test.go index 44185b0..5d05ff7 100644 --- a/varsig_test.go +++ b/varsig_test.go @@ -1,7 +1,9 @@ package varsig_test import ( + "encoding/base64" "encoding/hex" + "errors" "io" "testing" @@ -157,17 +159,31 @@ func TestDecode(t *testing.T) { }) } -// func TestReadUvarint(t *testing.T) { -// t.Parallel() +func mustVarsig[T varsig.Varsig](t *testing.T) func(*T, error) *T { + t.Helper() -// var r io.ByteReader = &bytes.Reader{} + return func(v *T, err error) *T { + if err != nil && ((*v).Version() != varsig.Version0 || !errors.Is(err, varsig.ErrMissingSignature)) { + t.Error(err) + } -// u, err := binary.ReadUvarint(r) -// require.ErrorIs(t, err, io.EOF) -// assert.Equal(t, uint64(0), u) + return v + } +} -// var buf []byte -// buf = binary.AppendUvarint(buf, 0x100) -// t.Log("0x100 varint:", hex.EncodeToString(buf)) -// t.Fail() -// } +func roundTrip[T varsig.Varsig](t *testing.T, in T, expEncHex string) T { + data := in.Encode() + assert.Equal(t, expEncHex, base64.RawStdEncoding.EncodeToString(data)) + + out, err := varsig.Decode(in.Encode()) + if err != nil && (out.Version() != varsig.Version0 || !errors.Is(err, varsig.ErrMissingSignature)) { + t.Fail() + } + + assert.Equal(t, in.Version(), out.Version()) + assert.Equal(t, in.SignatureAlgorithm(), out.SignatureAlgorithm()) + assert.Equal(t, in.PayloadEncoding(), out.PayloadEncoding()) + assert.Equal(t, in.Signature(), out.Signature()) + + return out.(T) +}