From cc1307523283272594f9545836941ec0382a02b5 Mon Sep 17 00:00:00 2001 From: Jeromy Date: Tue, 30 Aug 2016 10:04:50 -0700 Subject: [PATCH] add handling for V0 cids --- cid.go | 49 ++++++++++++++++++++++++++++++++++- cid_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 114 insertions(+), 8 deletions(-) diff --git a/cid.go b/cid.go index 785a74b..ff2cc21 100644 --- a/cid.go +++ b/cid.go @@ -2,10 +2,14 @@ package cid import ( "encoding/binary" + "fmt" + mh "github.com/jbenet/go-multihash" mbase "github.com/multiformats/go-multibase" ) +const UnsupportedVersionString = "" + type Cid struct { Version uint64 Type uint64 @@ -13,6 +17,18 @@ type Cid struct { } func Decode(v string) (*Cid, error) { + if len(v) == 46 && v[:2] == "Qm" { + hash, err := mh.FromB58String(v) + if err != nil { + return nil, err + } + + return &Cid{ + Version: 0, + Hash: hash, + }, nil + } + _, data, err := mbase.Decode(v) if err != nil { return nil, err @@ -38,7 +54,38 @@ func Cast(data []byte) (*Cid, error) { }, nil } -func (c *Cid) Bytes() []byte { +func (c *Cid) String() string { + switch c.Version { + case 0: + return c.Hash.B58String() + case 1: + mbstr, err := mbase.Encode(mbase.Base58BTC, c.bytesV1()) + if err != nil { + panic("should not error with hardcoded mbase: " + err.Error()) + } + + return mbstr + default: + return "" + } +} + +func (c *Cid) Bytes() ([]byte, error) { + switch c.Version { + case 0: + return c.bytesV0(), nil + case 1: + return c.bytesV1(), nil + default: + return nil, fmt.Errorf("unsupported cid version") + } +} + +func (c *Cid) bytesV0() []byte { + return []byte(c.Hash) +} + +func (c *Cid) bytesV1() []byte { buf := make([]byte, 8+len(c.Hash)) n := binary.PutUvarint(buf, c.Version) n += binary.PutUvarint(buf[n:], c.Type) diff --git a/cid_test.go b/cid_test.go index 8d274e3..2a1890b 100644 --- a/cid_test.go +++ b/cid_test.go @@ -7,6 +7,20 @@ import ( mh "github.com/jbenet/go-multihash" ) +func assertEqual(t *testing.T, a, b *Cid) { + if a.Type != b.Type { + t.Fatal("mismatch on type") + } + + if a.Version != b.Version { + t.Fatal("mismatch on version") + } + + if !bytes.Equal(a.Hash, b.Hash) { + t.Fatal("multihash mismatch") + } +} + func TestBasicMarshaling(t *testing.T) { h, err := mh.Sum([]byte("TEST"), mh.SHA3, 4) if err != nil { @@ -19,22 +33,67 @@ func TestBasicMarshaling(t *testing.T) { Hash: h, } - data := cid.Bytes() + data, err := cid.Bytes() + if err != nil { + t.Fatal(err) + } out, err := Cast(data) if err != nil { t.Fatal(err) } - if out.Type != cid.Type { - t.Fatal("mismatch on type") + assertEqual(t, cid, out) + + s := cid.String() + out2, err := Decode(s) + if err != nil { + t.Fatal(err) } - if out.Version != cid.Version { - t.Fatal("mismatch on version") + assertEqual(t, cid, out2) +} + +func TestV0Handling(t *testing.T) { + old := "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zR1n" + + cid, err := Decode(old) + if err != nil { + t.Fatal(err) } - if !bytes.Equal(out.Hash, cid.Hash) { - t.Fatal("multihash mismatch") + if cid.Version != 0 { + t.Fatal("should have gotten version 0 cid") + } + + if cid.Hash.B58String() != old { + t.Fatal("marshaling roundtrip failed") + } + + if cid.String() != old { + t.Fatal("marshaling roundtrip failed") + } +} + +func TestV0ErrorCases(t *testing.T) { + badb58 := "QmdfTbBqBPQ7VNxZEYEj14VmRuZBkqFbiwReogJgS1zIII" + _, err := Decode(badb58) + if err == nil { + t.Fatal("should have failed to decode that ref") + } +} + +func TestBadVersion(t *testing.T) { + c := &Cid{ + Version: 17, + } + + if c.String() != UnsupportedVersionString { + t.Fatal("expected unsup string") + } + + _, err := c.Bytes() + if err == nil { + t.Fatal("shouldnt have succeeded in calling bytes") } }