diff --git a/encoder.go b/encoder.go index d1f87f0..42e753f 100644 --- a/encoder.go +++ b/encoder.go @@ -19,6 +19,16 @@ func NewEncoder(base Encoding) (Encoder, error) { return Encoder{base}, nil } +// MustNewEncoder is like NewEncoder but will panic if the encoding is +// invalid. +func MustNewEncoder(base Encoding) Encoder { + _, ok := EncodingToStr[base] + if !ok { + panic("Unsupported multibase encoding") + } + return Encoder{base} +} + // EncoderByName creates an encoder from a string, the string can // either be the multibase name or single character multibase prefix func EncoderByName(str string) (Encoder, error) { diff --git a/encoder_test.go b/encoder_test.go index d09c805..3db2c13 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -4,7 +4,7 @@ import ( "testing" ) -func TestInvalidPrefix(t *testing.T) { +func TestInvalidCode(t *testing.T) { _, err := NewEncoder('q') if err == nil { t.Error("expected failure") @@ -21,23 +21,31 @@ func TestInvalidName(t *testing.T) { } } -func TestPrefix(t *testing.T) { - for str, base := range Encodings { - prefix, err := NewEncoder(base) - if err != nil { - t.Fatalf("NewEncoder(%c) failed: %v", base, err) - } - str1, err := Encode(base, sampleBytes) +func TestEncoder(t *testing.T) { + for name, code := range Encodings { + encoder, err := NewEncoder(code) if err != nil { t.Fatal(err) } - str2 := prefix.Encode(sampleBytes) - if str1 != str2 { - t.Errorf("encoded string mismatch: %s != %s", str1, str2) - } - _, err = EncoderByName(str) + // Make sure the MustNewEncoder doesn't panic + MustNewEncoder(code) + str, err := Encode(code, sampleBytes) if err != nil { - t.Fatalf("NewEncoder(%s) failed: %v", str, err) + t.Fatal(err) + } + str2 := encoder.Encode(sampleBytes) + if str != str2 { + t.Errorf("encoded string mismatch: %s != %s", str, str2) + } + _, err = EncoderByName(name) + if err != nil { + t.Fatalf("EncoderByName(%s) failed: %v", name, err) + } + // Test that an encoder can be created from the single letter + // prefix + _, err = EncoderByName(str[0:1]) + if err != nil { + t.Fatalf("EncoderByName(%s) failed: %v", str[0:1], err) } } } diff --git a/multibase_test.go b/multibase_test.go index d688e66..f389f5c 100644 --- a/multibase_test.go +++ b/multibase_test.go @@ -7,13 +7,13 @@ import ( ) func TestMap(t *testing.T) { - for s,e := range Encodings { + for s, e := range Encodings { s2 := EncodingToStr[e] if s != s2 { t.Errorf("round trip failed on encoding map: %s != %s", s, s2) } } - for e,s := range EncodingToStr { + for e, s := range EncodingToStr { e2 := Encodings[s] if e != e2 { t.Errorf("round trip failed on encoding map: '%c' != '%c'", e, e2)