Compare commits

...

6 Commits

Author SHA1 Message Date
Rod Vagg
d46e7f2866 v0.4.1 2023-04-04 14:08:15 +10:00
gammazero
83a0e939a4 Add unit test for unexpected eof 2023-04-04 14:08:15 +10:00
Andrew Gillis
0981f8566c Update cid.go
Co-authored-by: Rod Vagg <rod@vagg.org>
2023-04-04 14:08:15 +10:00
gammazero
166a3a6880 CidFromReader should not wrap valid EOF return.
When reading from an io.Reader that has no data, the io.EOF error should not be wrapped in ErrInvalidCid. This is not an invalid CID, and is not the same as a partial read which is indicated by io.ErrUnexpectedEOF.

This fix is needed because existing code that uses CidFromReader may check for the end of an input stream by `if err == io.EOF` instead of the preferred `if errors.Is(err, io.EOF)`, and that code break at runtime after upgrading to go-cid v0.4.0.
2023-04-04 14:08:15 +10:00
Henrique Dias
8098d66787 chore: version 0.4.0 2023-03-20 09:29:34 +01:00
Henrique Dias
b98e249130 feat: wrap parsing errors into ErrInvalidCid 2023-03-20 09:29:34 +01:00
3 changed files with 219 additions and 42 deletions

102
cid.go
View File

@@ -37,10 +37,32 @@ import (
// UnsupportedVersionString just holds an error message
const UnsupportedVersionString = "<unsupported cid version>"
// ErrInvalidCid is an error that indicates that a CID is invalid.
type ErrInvalidCid struct {
Err error
}
func (e ErrInvalidCid) Error() string {
return fmt.Sprintf("invalid cid: %s", e.Err)
}
func (e ErrInvalidCid) Unwrap() error {
return e.Err
}
func (e ErrInvalidCid) Is(err error) bool {
switch err.(type) {
case ErrInvalidCid, *ErrInvalidCid:
return true
default:
return false
}
}
var (
// ErrCidTooShort means that the cid passed to decode was not long
// enough to be a valid Cid
ErrCidTooShort = errors.New("cid too short")
ErrCidTooShort = ErrInvalidCid{errors.New("cid too short")}
// ErrInvalidEncoding means that selected encoding is not supported
// by this Cid version
@@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) {
// incorrectly detect it as CidV1 in the Version() method
dec, err := mh.Decode(mhash)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
if dec.Code != mh.SHA2_256 || dec.Length != 32 {
return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)
return Undef, ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)}
}
return Cid{string(mhash)}, nil
}
@@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) {
case Cid:
return v2, nil
default:
return Undef, fmt.Errorf("can't parse %+v as Cid", v2)
return Undef, ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)}
}
}
@@ -210,7 +232,7 @@ func Decode(v string) (Cid, error) {
if len(v) == 46 && v[:2] == "Qm" {
hash, err := mh.FromB58String(v)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
return tryNewCidV0(hash)
@@ -218,7 +240,7 @@ func Decode(v string) (Cid, error) {
_, data, err := mbase.Decode(v)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
return Cast(data)
@@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
// check encoding is valid
_, err := mbase.NewEncoder(encoding)
if err != nil {
return -1, err
return -1, ErrInvalidCid{err}
}
return encoding, nil
@@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
func Cast(data []byte) (Cid, error) {
nr, c, err := CidFromBytes(data)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
if nr != len(data) {
return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast")
return Undef, ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")}
}
return c, nil
@@ -434,7 +456,7 @@ func (c Cid) Equals(o Cid) bool {
// UnmarshalJSON parses the JSON representation of a Cid.
func (c *Cid) UnmarshalJSON(b []byte) error {
if len(b) < 2 {
return fmt.Errorf("invalid cid json blob")
return ErrInvalidCid{fmt.Errorf("invalid cid json blob")}
}
obj := struct {
CidTarget string `json:"/"`
@@ -442,7 +464,7 @@ func (c *Cid) UnmarshalJSON(b []byte) error {
objptr := &obj
err := json.Unmarshal(b, &objptr)
if err != nil {
return err
return ErrInvalidCid{err}
}
if objptr == nil {
*c = Cid{}
@@ -450,12 +472,12 @@ func (c *Cid) UnmarshalJSON(b []byte) error {
}
if obj.CidTarget == "" {
return fmt.Errorf("cid was incorrectly formatted")
return ErrInvalidCid{fmt.Errorf("cid was incorrectly formatted")}
}
out, err := Decode(obj.CidTarget)
if err != nil {
return err
return ErrInvalidCid{err}
}
*c = out
@@ -542,12 +564,12 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
if p.Version == 0 && (p.MhType != mh.SHA2_256 ||
(p.MhLength != 32 && p.MhLength != -1)) {
return Undef, fmt.Errorf("invalid v0 prefix")
return Undef, ErrInvalidCid{fmt.Errorf("invalid v0 prefix")}
}
hash, err := mh.Sum(data, p.MhType, length)
if err != nil {
return Undef, err
return Undef, ErrInvalidCid{err}
}
switch p.Version {
@@ -556,7 +578,7 @@ func (p Prefix) Sum(data []byte) (Cid, error) {
case 1:
return NewCidV1(p.Codec, hash), nil
default:
return Undef, fmt.Errorf("invalid cid version")
return Undef, ErrInvalidCid{fmt.Errorf("invalid cid version")}
}
}
@@ -586,22 +608,22 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
r := bytes.NewReader(buf)
vers, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}
codec, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}
mhtype, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}
mhlen, err := varint.ReadUvarint(r)
if err != nil {
return Prefix{}, err
return Prefix{}, ErrInvalidCid{err}
}
return Prefix{
@@ -615,12 +637,12 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
func CidFromBytes(data []byte) (int, Cid, error) {
if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 {
if len(data) < 34 {
return 0, Undef, fmt.Errorf("not enough bytes for cid v0")
return 0, Undef, ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")}
}
h, err := mh.Cast(data[:34])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}
return 34, Cid{string(h)}, nil
@@ -628,21 +650,21 @@ func CidFromBytes(data []byte) (int, Cid, error) {
vers, n, err := varint.FromUvarint(data)
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}
if vers != 1 {
return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
return 0, Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
}
_, cn, err := varint.FromUvarint(data[n:])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}
mhnr, _, err := mh.MHFromBytes(data[n+cn:])
if err != nil {
return 0, Undef, err
return 0, Undef, ErrInvalidCid{err}
}
l := n + cn + mhnr
@@ -695,6 +717,9 @@ func (r *bufByteReader) ReadByte() (byte, error) {
// It's recommended to supply a reader that buffers and implements io.ByteReader,
// as CidFromReader has to do many single-byte reads to decode varints.
// If the argument only implements io.Reader, single-byte Read calls are used instead.
//
// If the Reader is found to yield zero bytes, an io.EOF error is returned directly, in all
// other error cases, an ErrInvalidCid, wrapping the original error, is returned.
func CidFromReader(r io.Reader) (int, Cid, error) {
// 64 bytes is enough for any CIDv0,
// and it's enough for most CIDv1s in practice.
@@ -705,32 +730,37 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
vers, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
if err == io.EOF {
// First-byte read in ReadUvarint errors with io.EOF, so reader has no data.
// Subsequent reads with an EOF will return io.ErrUnexpectedEOF and be wrapped here.
return 0, Undef, err
}
return len(br.dst), Undef, ErrInvalidCid{err}
}
// If we have a CIDv0, read the rest of the bytes and cast the buffer.
if vers == mh.SHA2_256 {
if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
return len(br.dst) + n, Undef, err
return len(br.dst) + n, Undef, ErrInvalidCid{err}
}
br.dst = br.dst[:34]
h, err := mh.Cast(br.dst)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}
return len(br.dst), Cid{string(h)}, nil
}
if vers != 1 {
return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
}
// CID block encoding multicodec.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}
// We could replace most of the code below with go-multihash's ReadMultihash.
@@ -741,19 +771,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// Multihash hash function code.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}
// Multihash digest length.
mhl, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}
// Refuse to make large allocations to prevent OOMs due to bugs.
const maxDigestAlloc = 32 << 20 // 32MiB
if mhl > maxDigestAlloc {
return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)}
}
// Fine to convert mhl to int, given maxDigestAlloc.
@@ -772,7 +802,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
// We can't use len(br.dst) here,
// as we've only read n bytes past prefixLength.
return prefixLength + n, Undef, err
return prefixLength + n, Undef, ErrInvalidCid{err}
}
// This simply ensures the multihash is valid.
@@ -780,7 +810,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// for now, it helps ensure consistency with CidFromBytes.
_, _, err = mh.MHFromBytes(br.dst[mhStart:])
if err != nil {
return len(br.dst), Undef, err
return len(br.dst), Undef, ErrInvalidCid{err}
}
return len(br.dst), Cid{string(br.dst)}, nil

View File

@@ -4,6 +4,7 @@ import (
"bytes"
crand "crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
@@ -162,6 +163,9 @@ func TestBasesMarshaling(t *testing.T) {
if err == nil {
t.Fatal("expected too-short error from ExtractEncoding")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
if ee != -1 {
t.Fatal("expected -1 from too-short ExtractEncoding")
}
@@ -227,6 +231,9 @@ func TestEmptyString(t *testing.T) {
if err == nil {
t.Fatal("shouldnt be able to parse an empty cid")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("error must be ErrInvalidCid")
}
}
func TestV0Handling(t *testing.T) {
@@ -282,6 +289,9 @@ func TestV0ErrorCases(t *testing.T) {
if err == nil {
t.Fatal("should have failed to decode that ref")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("error must be ErrInvalidCid")
}
}
func TestNewPrefixV1(t *testing.T) {
@@ -372,6 +382,9 @@ func TestInvalidV0Prefix(t *testing.T) {
if err == nil {
t.Fatalf("should error (index %d)", i)
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
}
}
@@ -381,6 +394,9 @@ func TestBadPrefix(t *testing.T) {
if err == nil {
t.Fatalf("expected error on v3 prefix Sum")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
}
func TestPrefixRoundtrip(t *testing.T) {
@@ -417,18 +433,30 @@ func TestBadPrefixFromBytes(t *testing.T) {
if err == nil {
t.Fatal("expected error for bad byte 0")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
_, err = PrefixFromBytes([]byte{0x01, 0x80})
if err == nil {
t.Fatal("expected error for bad byte 1")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
_, err = PrefixFromBytes([]byte{0x01, 0x01, 0x80})
if err == nil {
t.Fatal("expected error for bad byte 2")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
_, err = PrefixFromBytes([]byte{0x01, 0x01, 0x01, 0x80})
if err == nil {
t.Fatal("expected error for bad byte 3")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
}
func Test16BytesVarint(t *testing.T) {
@@ -455,6 +483,9 @@ func TestParse(t *testing.T) {
if !strings.Contains(err.Error(), "can't parse 123 as Cid") {
t.Fatalf("expected int error, got %s", err.Error())
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatalf("expected ErrInvalidCid, got %s", err.Error())
}
theHash := "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n"
h, err := mh.FromB58String(theHash)
@@ -572,17 +603,29 @@ func TestJsonRoundTrip(t *testing.T) {
t.Fatal("cids not equal for Cid")
}
if err = actual2.UnmarshalJSON([]byte("1")); err == nil {
err = actual2.UnmarshalJSON([]byte("1"))
if err == nil {
t.Fatal("expected error for too-short JSON")
}
if err = actual2.UnmarshalJSON([]byte(`{"nope":"nope"}`)); err == nil {
t.Fatal("expected error for bad CID JSON")
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
if err = actual2.UnmarshalJSON([]byte(`bad "" json!`)); err == nil {
err = actual2.UnmarshalJSON([]byte(`{"nope":"nope"}`))
if err == nil {
t.Fatal("expected error for bad CID JSON")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
err = actual2.UnmarshalJSON([]byte(`bad "" json!`))
if err == nil {
t.Fatal("expected error for bad JSON")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("expected error to be ErrInvalidCid")
}
var actual3 Cid
enc, err = actual3.MarshalJSON()
@@ -740,6 +783,30 @@ func TestBadCidInput(t *testing.T) {
}
}
func TestFromReaderNoData(t *testing.T) {
// Reading no data from io.Reader should return io.EOF, not ErrInvalidCid.
n, cid, err := CidFromReader(bytes.NewReader(nil))
if err != io.EOF {
t.Fatal("Expected io.EOF error")
}
if cid != Undef {
t.Fatal("Expected Undef CID")
}
if n != 0 {
t.Fatal("Expected 0 data")
}
// Read byte indicatiing more data to and check error is ErrInvalidCid.
_, _, err = CidFromReader(bytes.NewReader([]byte{0x80}))
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("Expected ErrInvalidCid error")
}
// Check for expected wrapped error.
if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Fatal("Expected error", io.ErrUnexpectedEOF)
}
}
func TestBadParse(t *testing.T) {
hash, err := mh.Sum([]byte("foobar"), mh.SHA3_256, -1)
if err != nil {
@@ -749,6 +816,9 @@ func TestBadParse(t *testing.T) {
if err == nil {
t.Fatal("expected to fail to parse an invalid CIDv1 CID")
}
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("error must be ErrInvalidCid")
}
}
func TestLoggable(t *testing.T) {
@@ -763,3 +833,80 @@ func TestLoggable(t *testing.T) {
t.Fatalf("did not get expected loggable form (got %v)", actual)
}
}
func TestErrInvalidCidIs(t *testing.T) {
for i, test := range []struct {
err error
target error
}{
{&ErrInvalidCid{}, ErrInvalidCid{}},
{ErrInvalidCid{}, &ErrInvalidCid{}},
{ErrInvalidCid{}, ErrInvalidCid{}},
{&ErrInvalidCid{}, &ErrInvalidCid{}},
} {
if !errors.Is(test.err, test.target) {
t.Fatalf("expected error to be ErrInvalidCid, case %d", i)
}
}
}
func TestErrInvalidCid(t *testing.T) {
run := func(err error) {
if err == nil {
t.Fatal("expected error")
}
if !strings.HasPrefix(err.Error(), "invalid cid: ") {
t.Fatal(`expected error message to contain "invalid cid: "`)
}
is := errors.Is(err, ErrInvalidCid{})
if !is {
t.Fatal("expected error to be ErrInvalidCid")
}
if !errors.Is(err, &ErrInvalidCid{}) {
t.Fatal("expected error to be &ErrInvalidCid")
}
}
_, err := Decode("")
run(err)
_, err = Decode("not-a-cid")
run(err)
_, err = Decode("bafyInvalid")
run(err)
_, err = Decode("QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zIII")
run(err)
_, err = Cast([]byte("invalid"))
run(err)
_, err = Parse("not-a-cid")
run(err)
_, err = Parse("bafyInvalid")
run(err)
_, err = Parse("QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zIII")
run(err)
_, err = Parse(123)
run(err)
_, _, err = CidFromBytes([]byte("invalid"))
run(err)
_, err = Prefix{}.Sum([]byte("data"))
run(err)
_, err = PrefixFromBytes([]byte{0x80})
run(err)
_, err = ExtractEncoding("invalid ")
run(err)
}

View File

@@ -1,3 +1,3 @@
{
"version": "v0.3.2"
"version": "v0.4.1"
}