From d3378a0608b90d243888cf7fdcd644694b3ff583 Mon Sep 17 00:00:00 2001 From: Steve Moyer Date: Fri, 4 Jul 2025 10:04:40 -0400 Subject: [PATCH] feat(registry): adds default and custom registries for signing algorithms parsing --- registry.go | 99 ++++++++++++++++++++++++++++++ registry_test.go | 153 +++++++++++++++++++++++++++++++++++++++++++++++ varsig.go | 12 +++- 3 files changed, 261 insertions(+), 3 deletions(-) create mode 100644 registry.go create mode 100644 registry_test.go diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..02cae4e --- /dev/null +++ b/registry.go @@ -0,0 +1,99 @@ +package varsig + +import ( + "bytes" + "encoding/binary" + "fmt" +) + +type Version uint64 + +const ( + Version0 Version = 0 + Version1 Version = 1 +) + +// ParseFunc is a function that parses the varsig representing a specific +// signing algorithm. +type ParseFunc func(*bytes.Reader, Version, SignAlgorithm) (Varsig, error) + +// Registry contains a mapping between known signing algorithms, and +// functions that can parse varsigs for that signing algorithm. +type Registry map[SignAlgorithm]ParseFunc + +// DefaultRegistry provides a Registry containing the mappings for the +// signing algorithms which have an implementation within this library. +func DefaultRegistry() Registry { + return map[SignAlgorithm]ParseFunc{ + SignAlgorithmRSA: parseRSA, + SignAlgorithmEd25519: notYetImplementedVarsigParser, + SignAlgorithmECDSAP256: notYetImplementedVarsigParser, + SignAlgorithmECDSASecp256k1: notYetImplementedVarsigParser, + SignAlgorithmECDSAP521: notYetImplementedVarsigParser, + } +} + +// NewRegistry creates an empty Registry. +func NewRegistry() Registry { + return make(Registry) +} + +// Register allows new mappings between a signing algorithm and its parsing +// function to the Registry. +func (rs Registry) Register(alg SignAlgorithm, parseFn ParseFunc) { + rs[alg] = parseFn +} + +// Decode converts the provided data into one of the registered Varsig +// types. +func (rs Registry) Decode(data []byte) (Varsig, error) { + return rs.DecodeStream(bytes.NewReader(data)) +} + +// DecodeStream converts data read from the provided io.Reader into one +// of the registered Varsig types. +func (rs Registry) DecodeStream(r *bytes.Reader) (Varsig, error) { + pre, err := binary.ReadUvarint(r) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrBadPrefix, err) + } + + if pre != Prefix { + return nil, fmt.Errorf("%w: expected %d, got %d", ErrBadPrefix, Prefix, pre) + } + + vers, signAlg, err := rs.parseVersAndSignAlg(r) + if err != nil { + return nil, err + } + + parseFn, ok := rs[SignAlgorithm(signAlg)] + if !ok { + return nil, fmt.Errorf("%w: %x", ErrUnknownSignAlgorithm, signAlg) + } + + return parseFn(r, vers, signAlg) +} + +func (rs Registry) parseVersAndSignAlg(r *bytes.Reader) (Version, SignAlgorithm, error) { + vers, err := binary.ReadUvarint(r) + if err != nil { + return Version(vers), 0, err + } + + if vers > 1 && vers < 64 { + return Version(vers), 0, fmt.Errorf("%w: %d", ErrUnsupportedVersion, vers) + } + + if vers >= 64 { + return 0, SignAlgorithm(vers), nil + } + + signAlg, err := binary.ReadUvarint(r) + + return Version(vers), SignAlgorithm(signAlg), err +} + +func notYetImplementedVarsigParser(_ *bytes.Reader, vers Version, signAlg SignAlgorithm) (Varsig, error) { + return nil, fmt.Errorf("%w: Version: %d, SignAlgorithm: %x", ErrNotYetImplemented, vers, signAlg) +} diff --git a/registry_test.go b/registry_test.go new file mode 100644 index 0000000..8ad96b2 --- /dev/null +++ b/registry_test.go @@ -0,0 +1,153 @@ +package varsig_test + +import ( + "bytes" + "encoding/hex" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/selesy/go-varsig" +) + +func TestRegistry_Parse(t *testing.T) { + t.Parallel() + + t.Run("passes - v0", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("348120") + require.NoError(t, err) + + reg := testRegistry(t) + + vs, err := reg.DecodeStream(bytes.NewReader(data)) + require.NoError(t, err) + assert.Equal(t, varsig.Version0, vs.Version()) + assert.Equal(t, testSignAlgorithm1, vs.SignatureAlgorithm()) + }) + + t.Run("passes - v1", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("34018120") + require.NoError(t, err) + + reg := testRegistry(t) + + vs, err := reg.DecodeStream(bytes.NewReader(data)) + require.NoError(t, err) + assert.Equal(t, varsig.Version1, vs.Version()) + assert.Equal(t, testSignAlgorithm1, vs.SignatureAlgorithm()) + }) + + t.Run("fails - no data (empty prefix)", func(t *testing.T) { + t.Parallel() + + vs, err := varsig.Decode([]byte{}) + require.ErrorIs(t, err, io.EOF) + require.ErrorIs(t, err, varsig.ErrBadPrefix) + assert.Nil(t, vs) + }) + + t.Run("fails - wrong prefix", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("42") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrBadPrefix) + assert.Nil(t, vs) + }) + + t.Run("fails - unsupported version", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("3402") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnsupportedVersion) + assert.Nil(t, vs) + }) + + t.Run("fails - unknown signature algorithm - v0", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("3464") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnknownSignAlgorithm) + assert.Nil(t, vs) + }) + + t.Run("fails - unknown signature algorithm - v1", func(t *testing.T) { + t.Parallel() + + data, err := hex.DecodeString("340164") + require.NoError(t, err) + + vs, err := varsig.Decode(data) + require.ErrorIs(t, err, varsig.ErrUnknownSignAlgorithm) + assert.Nil(t, vs) + }) +} + +const ( + testSignAlgorithm0 varsig.SignAlgorithm = 0x1000 + testSignAlgorithm1 varsig.SignAlgorithm = 0x1001 +) + +func testRegistry(t *testing.T) varsig.Registry { + t.Helper() + + reg := varsig.NewRegistry() + reg.Register(testSignAlgorithm0, testParseFunc(t)) + reg.Register(testSignAlgorithm1, testParseFunc(t)) + + return reg +} + +var _ varsig.ParseFunc = testParseFunc(&testing.T{}) + +func testParseFunc(t *testing.T) varsig.ParseFunc { + t.Helper() + + return func(r *bytes.Reader, vers varsig.Version, signAlg varsig.SignAlgorithm) (varsig.Varsig, error) { + return &testVarsig{ + vers: vers, + signAlg: signAlg, + }, nil + } +} + +var _ varsig.Varsig = (*testVarsig)(nil) + +type testVarsig struct { + vers varsig.Version + signAlg varsig.SignAlgorithm +} + +func (v *testVarsig) Version() varsig.Version { + return v.vers +} + +func (v *testVarsig) SignatureAlgorithm() varsig.SignAlgorithm { + return v.signAlg +} + +func (v *testVarsig) PayloadEncoding() varsig.PayloadEncoding { + return 0 +} + +func (v *testVarsig) Signature() []byte { + return nil +} + +func (v *testVarsig) Encode() []byte { + return nil +} diff --git a/varsig.go b/varsig.go index 466afe6..4a8908b 100644 --- a/varsig.go +++ b/varsig.go @@ -43,10 +43,16 @@ type Varsig interface { Encode() []byte } -// Decode converts the provided data into one of the registered Varsig -// types. +// Decode converts the provided data into one of the Varsig types +// provided by the DefaultRegistry. func Decode(data []byte) (Varsig, error) { - return DefaultSignAlgorithmRegistry().Decode(bytes.NewReader(data)) + return DefaultRegistry().Decode(data) +} + +// DecodeStream converts data read from the provided io.Reader into one +// of the Varsig types provided by the DefaultRegistry. +func DecodeStream(r *bytes.Reader) (Varsig, error) { + return DefaultRegistry().DecodeStream(r) } type varsig struct {