diff --git a/registry_test.go b/registry_test.go index 8ad96b2..5f72b13 100644 --- a/registry_test.go +++ b/registry_test.go @@ -3,7 +3,6 @@ package varsig_test import ( "bytes" "encoding/hex" - "io" "testing" "github.com/stretchr/testify/assert" @@ -12,7 +11,7 @@ import ( "github.com/selesy/go-varsig" ) -func TestRegistry_Parse(t *testing.T) { +func TestRegistry_Decode(t *testing.T) { t.Parallel() t.Run("passes - v0", func(t *testing.T) { @@ -42,59 +41,6 @@ func TestRegistry_Parse(t *testing.T) { assert.Equal(t, varsig.Version1, vs.Version()) assert.Equal(t, testSignAlgorithm1, vs.SignatureAlgorithm()) }) - - t.Run("fails - no data (empty prefix)", func(t *testing.T) { - t.Parallel() - - vs, err := varsig.Decode([]byte{}) - require.ErrorIs(t, err, io.EOF) - require.ErrorIs(t, err, varsig.ErrBadPrefix) - assert.Nil(t, vs) - }) - - t.Run("fails - wrong prefix", func(t *testing.T) { - t.Parallel() - - data, err := hex.DecodeString("42") - require.NoError(t, err) - - vs, err := varsig.Decode(data) - require.ErrorIs(t, err, varsig.ErrBadPrefix) - assert.Nil(t, vs) - }) - - t.Run("fails - unsupported version", func(t *testing.T) { - t.Parallel() - - data, err := hex.DecodeString("3402") - require.NoError(t, err) - - vs, err := varsig.Decode(data) - require.ErrorIs(t, err, varsig.ErrUnsupportedVersion) - assert.Nil(t, vs) - }) - - t.Run("fails - unknown signature algorithm - v0", func(t *testing.T) { - t.Parallel() - - data, err := hex.DecodeString("3464") - require.NoError(t, err) - - vs, err := varsig.Decode(data) - require.ErrorIs(t, err, varsig.ErrUnknownSignAlgorithm) - assert.Nil(t, vs) - }) - - t.Run("fails - unknown signature algorithm - v1", func(t *testing.T) { - t.Parallel() - - data, err := hex.DecodeString("340164") - require.NoError(t, err) - - vs, err := varsig.Decode(data) - require.ErrorIs(t, err, varsig.ErrUnknownSignAlgorithm) - assert.Nil(t, vs) - }) } const ( @@ -112,16 +58,16 @@ func testRegistry(t *testing.T) varsig.Registry { return reg } -var _ varsig.ParseFunc = testParseFunc(&testing.T{}) - func testParseFunc(t *testing.T) varsig.ParseFunc { t.Helper() return func(r *bytes.Reader, vers varsig.Version, signAlg varsig.SignAlgorithm) (varsig.Varsig, error) { - return &testVarsig{ + v := &testVarsig{ vers: vers, signAlg: signAlg, - }, nil + } + + return v, nil } } @@ -130,6 +76,8 @@ var _ varsig.Varsig = (*testVarsig)(nil) type testVarsig struct { vers varsig.Version signAlg varsig.SignAlgorithm + payEnc varsig.PayloadEncoding + sig []byte } func (v *testVarsig) Version() varsig.Version { @@ -141,11 +89,11 @@ func (v *testVarsig) SignatureAlgorithm() varsig.SignAlgorithm { } func (v *testVarsig) PayloadEncoding() varsig.PayloadEncoding { - return 0 + return v.payEnc } func (v *testVarsig) Signature() []byte { - return nil + return v.sig } func (v *testVarsig) Encode() []byte { diff --git a/varsig.go b/varsig.go index 4a8908b..7a0fc6a 100644 --- a/varsig.go +++ b/varsig.go @@ -100,7 +100,14 @@ func (v *varsig) encode() []byte { return buf } -func (v *varsig) decodeSignature(r *bytes.Reader, varsig Varsig, expectedLength uint64) (Varsig, error) { +func (v *varsig) decodePayEncAndSig(r *bytes.Reader, varsig Varsig, expectedLength uint64) (Varsig, error) { + payEnc, err := DecodePayloadEncoding(r, v.Version()) + if err != nil { + return nil, err + } + + v.payEnc = payEnc + signature, err := io.ReadAll(r) if err != nil { return nil, err @@ -108,10 +115,10 @@ func (v *varsig) decodeSignature(r *bytes.Reader, varsig Varsig, expectedLength v.sig = signature - return v.validateSignature(varsig, expectedLength) + return v.validateSig(varsig, expectedLength) } -func (v *varsig) validateSignature(varsig Varsig, expectedLength uint64) (Varsig, error) { +func (v *varsig) validateSig(varsig Varsig, expectedLength uint64) (Varsig, error) { if v.Version() == Version0 && len(v.sig) == 0 { return varsig, ErrMissingSignature } diff --git a/varsig_test.go b/varsig_test.go new file mode 100644 index 0000000..44185b0 --- /dev/null +++ b/varsig_test.go @@ -0,0 +1,173 @@ +package varsig_test + +import ( + "encoding/hex" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/selesy/go-varsig" +) + +func TestDecode(t *testing.T) { + t.Parallel() + + t.Run("passes - section 3 example", func(t *testing.T) { + t.Skip() + t.Parallel() + data, err := hex.DecodeString("34ed01ae3784f03f9ee1163382fa6efa73b0c31ecf58c899c836709303ba4621d1e6df20e09aaa568914290b7ea124f5b38e70b9b69c7de0d216880eac885edd41c302") + require.NoError(t, err) + + // TODO + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrNotYetImplemented) + assert.Equal(t, nil, vs) + }) + + t.Run("fails - no data (empty prefix)", func(t *testing.T) { + t.Parallel() + + vs, err := varsig.Decode([]byte{}) + require.ErrorIs(t, err, io.EOF) + require.ErrorIs(t, err, varsig.ErrBadPrefix) + assert.Nil(t, vs) + }) + + t.Run("fails - wrong prefix", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("42") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrBadPrefix) + assert.Nil(t, vs) + }) + + t.Run("fails - unsupported version", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("3402") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnsupportedVersion) + assert.Nil(t, vs) + }) + + t.Run("fails - unknown signature algorithm - v0", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("3464") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnknownSignAlgorithm) + assert.Nil(t, vs) + }) + + t.Run("fails - unknown signature algorithm - v1", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("340164") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnknownSignAlgorithm) + assert.Nil(t, vs) + }) + + // The tests below this point require the RSAVarsig implementation + // in order to test the private varsig.decodePayEncAndSig method. + + const ( + rsaHex = "8524" + sha256Hex = "12" + keyLen = "8002" + rsaBaseV0 = "34" + rsaHex + sha256Hex + keyLen + rsaBaseV1 = "3401" + rsaHex + sha256Hex + keyLen + ) + + t.Run("passes - v1", func(t *testing.T) { + t.Parallel() + data, err := hex.DecodeString(rsaBaseV1 + "5f") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.NoError(t, err) + assert.NotNil(t, vs) + }) + + t.Run("fails - truncated varsig (no payload encoding)", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString(rsaBaseV1) + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnsupportedPayloadEncoding) + require.ErrorIs(t, err, io.EOF) + assert.Nil(t, vs) + }) + + t.Run("fails - unsupported payload encoding", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString(rsaBaseV1 + "42") // 0x42 is not a valid payload encoding + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnsupportedPayloadEncoding) + assert.Nil(t, vs) + }) + + t.Run("fails - unexpected signature length - v0", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString(rsaBaseV0 + "5f" + "42") // 0x42 is only a single byte - 256 bytes are expected + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnexpectedSignatureSize) + assert.Nil(t, vs) + }) + + t.Run("fails - unexpected signature present - v1", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString(rsaBaseV1 + "5f" + "42") // 0x42 is only a single byte - 256 bytes are expected + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnexpectedSignaturePresent) + assert.Nil(t, vs) + }) + + t.Run("passes with error - v0", func(t *testing.T) { + t.Parallel() + data, err := hex.DecodeString(rsaBaseV0 + "5f") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrMissingSignature) + assert.NotNil(t, vs) // varsig is still returned with just "header" + }) +} + +// func TestReadUvarint(t *testing.T) { +// t.Parallel() + +// var r io.ByteReader = &bytes.Reader{} + +// u, err := binary.ReadUvarint(r) +// require.ErrorIs(t, err, io.EOF) +// assert.Equal(t, uint64(0), u) + +// var buf []byte +// buf = binary.AppendUvarint(buf, 0x100) +// t.Log("0x100 varint:", hex.EncodeToString(buf)) +// t.Fail() +// }