Compare commits

...

2 Commits

Author SHA1 Message Date
Rod Vagg
fc27a0c8a4 fix: error not as pointer 2023-03-17 11:05:58 +11:00
Henrique Dias
65e3baa1cb feat: wrap parsing errors into ErrInvalidCid 2023-03-16 08:44:03 +01:00
2 changed files with 72 additions and 25 deletions

72
cid.go
View File

@@ -37,10 +37,32 @@ import (
// UnsupportedVersionString just holds an error message // UnsupportedVersionString just holds an error message
const UnsupportedVersionString = "<unsupported cid version>" const UnsupportedVersionString = "<unsupported cid version>"
// ErrInvalidCid is an error that indicates that a CID is invalid.
type ErrInvalidCid struct {
Err error
}
func (e ErrInvalidCid) Error() string {
return fmt.Sprintf("invalid cid: %s", e.Err)
}
func (e ErrInvalidCid) Unwrap() error {
return e.Err
}
func (e ErrInvalidCid) Is(err error) bool {
switch err.(type) {
case ErrInvalidCid, *ErrInvalidCid:
return true
default:
return false
}
}
var ( var (
// ErrCidTooShort means that the cid passed to decode was not long // ErrCidTooShort means that the cid passed to decode was not long
// enough to be a valid Cid // enough to be a valid Cid
ErrCidTooShort = errors.New("cid too short") ErrCidTooShort = ErrInvalidCid{errors.New("cid too short")}
// ErrInvalidEncoding means that selected encoding is not supported // ErrInvalidEncoding means that selected encoding is not supported
// by this Cid version // by this Cid version
@@ -90,10 +112,10 @@ func tryNewCidV0(mhash mh.Multihash) (Cid, error) {
// incorrectly detect it as CidV1 in the Version() method // incorrectly detect it as CidV1 in the Version() method
dec, err := mh.Decode(mhash) dec, err := mh.Decode(mhash)
if err != nil { if err != nil {
return Undef, err return Undef, ErrInvalidCid{err}
} }
if dec.Code != mh.SHA2_256 || dec.Length != 32 { if dec.Code != mh.SHA2_256 || dec.Length != 32 {
return Undef, fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length) return Undef, ErrInvalidCid{fmt.Errorf("invalid hash for cidv0 %d-%d", dec.Code, dec.Length)}
} }
return Cid{string(mhash)}, nil return Cid{string(mhash)}, nil
} }
@@ -177,7 +199,7 @@ func Parse(v interface{}) (Cid, error) {
case Cid: case Cid:
return v2, nil return v2, nil
default: default:
return Undef, fmt.Errorf("can't parse %+v as Cid", v2) return Undef, ErrInvalidCid{fmt.Errorf("can't parse %+v as Cid", v2)}
} }
} }
@@ -210,7 +232,7 @@ func Decode(v string) (Cid, error) {
if len(v) == 46 && v[:2] == "Qm" { if len(v) == 46 && v[:2] == "Qm" {
hash, err := mh.FromB58String(v) hash, err := mh.FromB58String(v)
if err != nil { if err != nil {
return Undef, err return Undef, ErrInvalidCid{err}
} }
return tryNewCidV0(hash) return tryNewCidV0(hash)
@@ -218,7 +240,7 @@ func Decode(v string) (Cid, error) {
_, data, err := mbase.Decode(v) _, data, err := mbase.Decode(v)
if err != nil { if err != nil {
return Undef, err return Undef, ErrInvalidCid{err}
} }
return Cast(data) return Cast(data)
@@ -240,7 +262,7 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
// check encoding is valid // check encoding is valid
_, err := mbase.NewEncoder(encoding) _, err := mbase.NewEncoder(encoding)
if err != nil { if err != nil {
return -1, err return -1, ErrInvalidCid{err}
} }
return encoding, nil return encoding, nil
@@ -260,11 +282,11 @@ func ExtractEncoding(v string) (mbase.Encoding, error) {
func Cast(data []byte) (Cid, error) { func Cast(data []byte) (Cid, error) {
nr, c, err := CidFromBytes(data) nr, c, err := CidFromBytes(data)
if err != nil { if err != nil {
return Undef, err return Undef, ErrInvalidCid{err}
} }
if nr != len(data) { if nr != len(data) {
return Undef, fmt.Errorf("trailing bytes in data buffer passed to cid Cast") return Undef, ErrInvalidCid{fmt.Errorf("trailing bytes in data buffer passed to cid Cast")}
} }
return c, nil return c, nil
@@ -615,12 +637,12 @@ func PrefixFromBytes(buf []byte) (Prefix, error) {
func CidFromBytes(data []byte) (int, Cid, error) { func CidFromBytes(data []byte) (int, Cid, error) {
if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 { if len(data) > 2 && data[0] == mh.SHA2_256 && data[1] == 32 {
if len(data) < 34 { if len(data) < 34 {
return 0, Undef, fmt.Errorf("not enough bytes for cid v0") return 0, Undef, ErrInvalidCid{fmt.Errorf("not enough bytes for cid v0")}
} }
h, err := mh.Cast(data[:34]) h, err := mh.Cast(data[:34])
if err != nil { if err != nil {
return 0, Undef, err return 0, Undef, ErrInvalidCid{err}
} }
return 34, Cid{string(h)}, nil return 34, Cid{string(h)}, nil
@@ -628,21 +650,21 @@ func CidFromBytes(data []byte) (int, Cid, error) {
vers, n, err := varint.FromUvarint(data) vers, n, err := varint.FromUvarint(data)
if err != nil { if err != nil {
return 0, Undef, err return 0, Undef, ErrInvalidCid{err}
} }
if vers != 1 { if vers != 1 {
return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) return 0, Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
} }
_, cn, err := varint.FromUvarint(data[n:]) _, cn, err := varint.FromUvarint(data[n:])
if err != nil { if err != nil {
return 0, Undef, err return 0, Undef, ErrInvalidCid{err}
} }
mhnr, _, err := mh.MHFromBytes(data[n+cn:]) mhnr, _, err := mh.MHFromBytes(data[n+cn:])
if err != nil { if err != nil {
return 0, Undef, err return 0, Undef, ErrInvalidCid{err}
} }
l := n + cn + mhnr l := n + cn + mhnr
@@ -705,32 +727,32 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// The varint package wants a io.ByteReader, so we must wrap our io.Reader. // The varint package wants a io.ByteReader, so we must wrap our io.Reader.
vers, err := varint.ReadUvarint(br) vers, err := varint.ReadUvarint(br)
if err != nil { if err != nil {
return len(br.dst), Undef, err return len(br.dst), Undef, ErrInvalidCid{err}
} }
// If we have a CIDv0, read the rest of the bytes and cast the buffer. // If we have a CIDv0, read the rest of the bytes and cast the buffer.
if vers == mh.SHA2_256 { if vers == mh.SHA2_256 {
if n, err := io.ReadFull(r, br.dst[1:34]); err != nil { if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
return len(br.dst) + n, Undef, err return len(br.dst) + n, Undef, ErrInvalidCid{err}
} }
br.dst = br.dst[:34] br.dst = br.dst[:34]
h, err := mh.Cast(br.dst) h, err := mh.Cast(br.dst)
if err != nil { if err != nil {
return len(br.dst), Undef, err return len(br.dst), Undef, ErrInvalidCid{err}
} }
return len(br.dst), Cid{string(h)}, nil return len(br.dst), Cid{string(h)}, nil
} }
if vers != 1 { if vers != 1 {
return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers) return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("expected 1 as the cid version number, got: %d", vers)}
} }
// CID block encoding multicodec. // CID block encoding multicodec.
_, err = varint.ReadUvarint(br) _, err = varint.ReadUvarint(br)
if err != nil { if err != nil {
return len(br.dst), Undef, err return len(br.dst), Undef, ErrInvalidCid{err}
} }
// We could replace most of the code below with go-multihash's ReadMultihash. // We could replace most of the code below with go-multihash's ReadMultihash.
@@ -741,19 +763,19 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// Multihash hash function code. // Multihash hash function code.
_, err = varint.ReadUvarint(br) _, err = varint.ReadUvarint(br)
if err != nil { if err != nil {
return len(br.dst), Undef, err return len(br.dst), Undef, ErrInvalidCid{err}
} }
// Multihash digest length. // Multihash digest length.
mhl, err := varint.ReadUvarint(br) mhl, err := varint.ReadUvarint(br)
if err != nil { if err != nil {
return len(br.dst), Undef, err return len(br.dst), Undef, ErrInvalidCid{err}
} }
// Refuse to make large allocations to prevent OOMs due to bugs. // Refuse to make large allocations to prevent OOMs due to bugs.
const maxDigestAlloc = 32 << 20 // 32MiB const maxDigestAlloc = 32 << 20 // 32MiB
if mhl > maxDigestAlloc { if mhl > maxDigestAlloc {
return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl) return len(br.dst), Undef, ErrInvalidCid{fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)}
} }
// Fine to convert mhl to int, given maxDigestAlloc. // Fine to convert mhl to int, given maxDigestAlloc.
@@ -772,7 +794,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil { if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
// We can't use len(br.dst) here, // We can't use len(br.dst) here,
// as we've only read n bytes past prefixLength. // as we've only read n bytes past prefixLength.
return prefixLength + n, Undef, err return prefixLength + n, Undef, ErrInvalidCid{err}
} }
// This simply ensures the multihash is valid. // This simply ensures the multihash is valid.
@@ -780,7 +802,7 @@ func CidFromReader(r io.Reader) (int, Cid, error) {
// for now, it helps ensure consistency with CidFromBytes. // for now, it helps ensure consistency with CidFromBytes.
_, _, err = mh.MHFromBytes(br.dst[mhStart:]) _, _, err = mh.MHFromBytes(br.dst[mhStart:])
if err != nil { if err != nil {
return len(br.dst), Undef, err return len(br.dst), Undef, ErrInvalidCid{err}
} }
return len(br.dst), Cid{string(br.dst)}, nil return len(br.dst), Cid{string(br.dst)}, nil

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
crand "crypto/rand" crand "crypto/rand"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
@@ -227,6 +228,9 @@ func TestEmptyString(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("shouldnt be able to parse an empty cid") t.Fatal("shouldnt be able to parse an empty cid")
} }
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("error must be ErrInvalidCid")
}
} }
func TestV0Handling(t *testing.T) { func TestV0Handling(t *testing.T) {
@@ -282,6 +286,9 @@ func TestV0ErrorCases(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("should have failed to decode that ref") t.Fatal("should have failed to decode that ref")
} }
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("error must be ErrInvalidCid")
}
} }
func TestNewPrefixV1(t *testing.T) { func TestNewPrefixV1(t *testing.T) {
@@ -455,6 +462,9 @@ func TestParse(t *testing.T) {
if !strings.Contains(err.Error(), "can't parse 123 as Cid") { if !strings.Contains(err.Error(), "can't parse 123 as Cid") {
t.Fatalf("expected int error, got %s", err.Error()) t.Fatalf("expected int error, got %s", err.Error())
} }
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatalf("expected ErrInvalidCid, got %s", err.Error())
}
theHash := "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n" theHash := "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n"
h, err := mh.FromB58String(theHash) h, err := mh.FromB58String(theHash)
@@ -749,6 +759,9 @@ func TestBadParse(t *testing.T) {
if err == nil { if err == nil {
t.Fatal("expected to fail to parse an invalid CIDv1 CID") t.Fatal("expected to fail to parse an invalid CIDv1 CID")
} }
if !errors.Is(err, ErrInvalidCid{}) {
t.Fatal("error must be ErrInvalidCid")
}
} }
func TestLoggable(t *testing.T) { func TestLoggable(t *testing.T) {
@@ -763,3 +776,15 @@ func TestLoggable(t *testing.T) {
t.Fatalf("did not get expected loggable form (got %v)", actual) t.Fatalf("did not get expected loggable form (got %v)", actual)
} }
} }
func TestErrInvalidCid(t *testing.T) {
_, err := Decode("not-a-cid")
if err == nil {
t.Fatal("expected error")
}
is := errors.Is(err, ErrInvalidCid{})
if !is {
t.Fatal("expected error to be ErrInvalidCid")
}
}