diff --git a/pkg/args/args.go b/pkg/args/args.go index 5436a37..d20616f 100644 --- a/pkg/args/args.go +++ b/pkg/args/args.go @@ -5,41 +5,17 @@ package args import ( "fmt" - "reflect" - "sync" + "sort" - "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" - "github.com/ipld/go-ipld-prime/node/bindnode" - "github.com/ipld/go-ipld-prime/schema" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ucan-wg/go-ucan/pkg/policy/literal" ) -const ( - argsSchema = "type Args { String : Any }" - argsName = "Args" -) - -var ( - once sync.Once - ts *schema.TypeSystem - err error -) - -func argsType() schema.Type { - once.Do(func() { - ts, err = ipld.LoadSchemaBytes([]byte(argsSchema)) - }) - if err != nil { - panic(err) - } - - return ts.TypeByName(argsName) -} - -var ErrUnsupported = fmt.Errorf("failure adding unsupported type to meta") - -// Args are the Command's argumennts when an invocation Token is processed +// Args are the Command's arguments when an invocation Token is processed // by the executor. // // This type must be compatible with the IPLD type represented by the IPLD @@ -56,41 +32,23 @@ func New() *Args { } } -// FromIPLD unwraps an Args instance from an ipld.Node. -func FromIPLD(node ipld.Node) (*Args, error) { - var err error - - defer func() { - err = handlePanic(recover()) - }() - - obj := bindnode.Unwrap(node) - - args, ok := obj.(*Args) - if !ok { - err = fmt.Errorf("failed to convert to Args") - } - - return args, err -} - // Add inserts a key/value pair in the Args set. // // Accepted types for val are: bool, string, int, int8, int16, // int32, int64, uint, uint8, uint16, uint32, float32, float64, []byte, // []any, map[string]any, ipld.Node and nil. -func (m *Args) Add(key string, val any) error { - if _, ok := m.Values[key]; ok { +func (a *Args) Add(key string, val any) error { + if _, ok := a.Values[key]; ok { return fmt.Errorf("duplicate key %q", key) } - node, err := anyNode(val) + node, err := literal.Any(val) if err != nil { return err } - m.Values[key] = node - m.Keys = append(m.Keys, key) + a.Values[key] = node + a.Keys = append(a.Keys, key) return nil } @@ -99,108 +57,39 @@ func (m *Args) Add(key string, val any) error { // // If duplicate keys are encountered, the new value is silently dropped // without causing an error. -func (m *Args) Include(other *Args) { +func (a *Args) Include(other *Args) { for _, key := range other.Keys { - if _, ok := m.Values[key]; ok { + if _, ok := a.Values[key]; ok { // don't overwrite continue } - m.Values[key] = other.Values[key] - m.Keys = append(m.Keys, key) + a.Values[key] = other.Values[key] + a.Keys = append(a.Keys, key) } } // ToIPLD wraps an instance of an Args with an ipld.Node. -func (m *Args) ToIPLD() (ipld.Node, error) { - var err error - - defer func() { - err = handlePanic(recover()) - }() - - return bindnode.Wrap(m, argsType()), err +func (a *Args) ToIPLD() (ipld.Node, error) { + sort.Strings(a.Keys) + return qp.BuildMap(basicnode.Prototype.Any, int64(len(a.Keys)), func(ma datamodel.MapAssembler) { + for _, key := range a.Keys { + qp.MapEntry(ma, key, qp.Node(a.Values[key])) + } + }) } -func anyNode(val any) (ipld.Node, error) { - var err error - - defer func() { - err = handlePanic(recover()) - }() - - if val == nil { - return literal.Null(), nil +// Equals tells if two Args hold the same values. +func (a *Args) Equals(other *Args) bool { + if len(a.Keys) != len(other.Keys) { + return false } - - if cast, ok := val.(ipld.Node); ok { - return cast, nil + if len(a.Values) != len(other.Values) { + return false } - - if cast, ok := val.(cid.Cid); ok { - return literal.LinkCid(cast), err - } - - var rv reflect.Value - - rv.Kind() - - if cast, ok := val.(reflect.Value); ok { - rv = cast - } else { - rv = reflect.ValueOf(val) - } - - for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { - rv = rv.Elem() - } - - switch rv.Kind() { - case reflect.Slice: - if rv.Type().Elem().Kind() == reflect.Uint8 { - return literal.Bytes(val.([]byte)), nil + for _, key := range a.Keys { + if !ipld.DeepEqual(a.Values[key], other.Values[key]) { + return false } - - l := make([]reflect.Value, rv.Len()) - - for i := 0; i < rv.Len(); i++ { - l[i] = rv.Index(i) - } - - return literal.List(l) - case reflect.Map: - if rv.Type().Key().Kind() != reflect.String { - return nil, fmt.Errorf("unsupported map key type: %s", rv.Type().Key().Name()) - } - - m := make(map[string]reflect.Value, rv.Len()) - it := rv.MapRange() - - for it.Next() { - m[it.Key().String()] = it.Value() - } - - return literal.Map(m) - case reflect.String: - return literal.String(rv.String()), nil - case reflect.Bool: - return literal.Bool(rv.Bool()), nil - // reflect.Int64 may exceed the safe 53-bit limit of JavaScript - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return literal.Int(rv.Int()), nil - // reflect.Uint64 can't be safely converted to int64 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return literal.Int(int64(rv.Uint())), nil - case reflect.Float32, reflect.Float64: - return literal.Float(rv.Float()), nil - default: - return nil, fmt.Errorf("unsupported Args type: %s", rv.Type().Name()) } -} - -func handlePanic(rec any) error { - if err, ok := rec.(error); ok { - return err - } - - return fmt.Errorf("%v", rec) + return true } diff --git a/pkg/args/args_test.go b/pkg/args/args_test.go index 2cccb82..8a1fda2 100644 --- a/pkg/args/args_test.go +++ b/pkg/args/args_test.go @@ -8,40 +8,14 @@ import ( "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagcbor" "github.com/ipld/go-ipld-prime/datamodel" - "github.com/ipld/go-ipld-prime/node/bindnode" "github.com/ipld/go-ipld-prime/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/pkg/args" "github.com/ucan-wg/go-ucan/pkg/policy/literal" ) -const ( - argsSchema = "type Args { String : Any }" - argsName = "Args" -) - -var ( - once sync.Once - ts *schema.TypeSystem - err error -) - -func argsType() schema.Type { - once.Do(func() { - ts, err = ipld.LoadSchemaBytes([]byte(argsSchema)) - }) - if err != nil { - panic(err) - } - - return ts.TypeByName(argsName) -} - -func argsPrototype() schema.TypedPrototype { - return bindnode.Prototype((*args.Args)(nil), argsType()) -} - func TestArgs(t *testing.T) { t.Parallel() @@ -78,9 +52,6 @@ func TestArgs(t *testing.T) { argsIn := args.New() - // WARNING: Do not change the order of these statements as this is the - // order which will be present when decoded from DAG-CBOR ( - // per RFC7049 default canonical ordering?). for _, a := range []struct { key string val any @@ -91,8 +62,8 @@ func TestArgs(t *testing.T) { {key: boolKey, val: expBoolVal}, {key: linkKey, val: expLinkVal}, {key: listKey, val: expListVal}, - {key: nodeKey, val: expNodeVal}, {key: uintKey, val: expUintVal}, + {key: nodeKey, val: expNodeVal}, {key: bytesKey, val: expBytesVal}, {key: floatKey, val: expFloatVal}, {key: stringKey, val: expStringVal}, @@ -100,21 +71,19 @@ func TestArgs(t *testing.T) { require.NoError(t, argsIn.Add(a.key, a.val)) } - // Round-trip to DAG-CBOR here as ToIPLD/FromIPLD is only a wrapper + // Round-trip to DAG-CBOR argsOut := roundTripThroughDAGCBOR(t, argsIn) - assert.Equal(t, argsIn, argsOut) + assert.ElementsMatch(t, argsIn.Keys, argsOut.Keys) + assert.Equal(t, argsIn.Values, argsOut.Values) actMapVal := map[string]string{} mit := argsOut.Values[mapKey].MapIterator() - es := errorSwallower(t) - for !mit.Done() { k, v, err := mit.Next() require.NoError(t, err) - ks := es(k.AsString()).(string) - vs := es(v.AsString()).(string) - + ks := must(k.AsString()) + vs := must(v.AsString()) actMapVal[ks] = vs } @@ -124,23 +93,23 @@ func TestArgs(t *testing.T) { for !lit.Done() { _, v, err := lit.Next() require.NoError(t, err) - vs := es(v.AsString()).(string) + vs := must(v.AsString()) actListVal = append(actListVal, vs) } - assert.Equal(t, expIntVal, es(argsOut.Values[intKey].AsInt())) + assert.Equal(t, expIntVal, must(argsOut.Values[intKey].AsInt())) assert.Equal(t, expMapVal, actMapVal) // TODO: special accessor // TODO: the nil map comes back empty (but the right type) // assert.Equal(t, expNilVal, actNilVal) - assert.Equal(t, expBoolVal, es(argsOut.Values[boolKey].AsBool())) - assert.Equal(t, expLinkVal.String(), es(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor - assert.Equal(t, expListVal, actListVal) // TODO: special accessor + assert.Equal(t, expBoolVal, must(argsOut.Values[boolKey].AsBool())) + assert.Equal(t, expLinkVal.String(), must(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor + assert.Equal(t, expListVal, actListVal) // TODO: special accessor assert.Equal(t, expNodeVal, argsOut.Values[nodeKey]) - assert.Equal(t, expUintVal, uint(es(argsOut.Values[uintKey].AsInt()).(int64))) - assert.Equal(t, expBytesVal, es(argsOut.Values[bytesKey].AsBytes())) - assert.Equal(t, expFloatVal, es(argsOut.Values[floatKey].AsFloat())) - assert.Equal(t, expStringVal, es(argsOut.Values[stringKey].AsString())) + assert.Equal(t, expUintVal, uint(must(argsOut.Values[uintKey].AsInt()))) + assert.Equal(t, expBytesVal, must(argsOut.Values[bytesKey].AsBytes())) + assert.Equal(t, expFloatVal, must(argsOut.Values[floatKey].AsFloat())) + assert.Equal(t, expStringVal, must(argsOut.Values[stringKey].AsString())) } func TestArgs_Include(t *testing.T) { @@ -157,21 +126,33 @@ func TestArgs_Include(t *testing.T) { argsIn.Include(argsOther) - es := errorSwallower(t) - assert.Len(t, argsIn.Values, 4) - assert.Equal(t, "val1", es(argsIn.Values["key1"].AsString())) - assert.Equal(t, "val2", es(argsIn.Values["key2"].AsString())) - assert.Equal(t, "val3", es(argsIn.Values["key3"].AsString())) - assert.Equal(t, "val4", es(argsIn.Values["key4"].AsString())) + assert.Equal(t, "val1", must(argsIn.Values["key1"].AsString())) + assert.Equal(t, "val2", must(argsIn.Values["key2"].AsString())) + assert.Equal(t, "val3", must(argsIn.Values["key3"].AsString())) + assert.Equal(t, "val4", must(argsIn.Values["key4"].AsString())) } -func errorSwallower(t *testing.T) func(any, error) any { - return func(val any, err error) any { - require.NoError(t, err) +const ( + argsSchema = "type Args { String : Any }" + argsName = "Args" +) - return val +var ( + once sync.Once + ts *schema.TypeSystem + err error +) + +func argsType() schema.Type { + once.Do(func() { + ts, err = ipld.LoadSchemaBytes([]byte(argsSchema)) + }) + if err != nil { + panic(err) } + + return ts.TypeByName(argsName) } func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args { @@ -182,11 +163,17 @@ func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args { data, err := ipld.Encode(node, dagcbor.Encode) require.NoError(t, err) - node, err = ipld.DecodeUsingPrototype(data, dagcbor.Decode, argsPrototype()) + + var argsOut args.Args + _, err = ipld.Unmarshal(data, dagcbor.Decode, &argsOut, argsType()) require.NoError(t, err) - argsOut, err := args.FromIPLD(node) - require.NoError(t, err) - - return argsOut + return &argsOut +} + +func must[T any](t T, err error) T { + if err != nil { + panic(err) + } + return t } diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index 65ef32c..5e5df8a 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -55,6 +55,23 @@ func List[T any](l []T) (ipld.Node, error) { }) } +// Any creates an IPLD node from any value +// If possible, use another dedicated function for your type for performance. +func Any(v any) (res ipld.Node, err error) { + builder := basicnode.Prototype__Any{}.NewBuilder() + + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%v", r) + res = nil + } + }() + + anyAssemble(v)(builder) + + return builder.Build(), nil +} + func anyAssemble(val any) qp.Assemble { var rt reflect.Type var rv reflect.Value @@ -117,6 +134,11 @@ func anyAssemble(val any) qp.Assemble { return qp.Float(rv.Float()) case reflect.String: return qp.String(rv.String()) + case reflect.Struct: + if rt == reflect.TypeOf(cid.Cid{}) { + c := rv.Interface().(cid.Cid) + return qp.Link(cidlink.Link{Cid: c}) + } default: } diff --git a/pkg/policy/literal/literal_test.go b/pkg/policy/literal/literal_test.go index 8320c85..656b82b 100644 --- a/pkg/policy/literal/literal_test.go +++ b/pkg/policy/literal/literal_test.go @@ -3,7 +3,9 @@ package literal import ( "testing" + "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime/datamodel" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/printer" "github.com/stretchr/testify/require" ) @@ -53,6 +55,7 @@ func TestMap(t *testing.T) { "barbar": "foo", }, }, + "link": cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"), }) require.NoError(t, err) @@ -115,6 +118,104 @@ func TestMap(t *testing.T) { string{"barbar"}: string{"foo"} } }`, printer.Sprint(v)) + + v, err = n.LookupByString("link") + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Link, v.Kind()) + asLink, err := v.AsLink() + require.NoError(t, err) + require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"))) +} + +func TestAny(t *testing.T) { + data := map[string]any{ + "bool": true, + "string": "foobar", + "bytes": []byte{1, 2, 3, 4}, + "int": 1234, + "uint": uint(12345), + "float": 1.45, + "slice": []int{1, 2, 3}, + "array": [2]int{1, 2}, + "map": map[string]any{ + "foo": "bar", + "foofoo": map[string]string{ + "barbar": "foo", + }, + }, + "link": cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"), + "func": func() {}, + } + + v, err := Any(data["bool"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Bool, v.Kind()) + require.Equal(t, true, must(v.AsBool())) + + v, err = Any(data["string"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_String, v.Kind()) + require.Equal(t, "foobar", must(v.AsString())) + + v, err = Any(data["bytes"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Bytes, v.Kind()) + require.Equal(t, []byte{1, 2, 3, 4}, must(v.AsBytes())) + + v, err = Any(data["int"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Int, v.Kind()) + require.Equal(t, int64(1234), must(v.AsInt())) + + v, err = Any(data["uint"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Int, v.Kind()) + require.Equal(t, int64(12345), must(v.AsInt())) + + v, err = Any(data["float"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Float, v.Kind()) + require.Equal(t, 1.45, must(v.AsFloat())) + + v, err = Any(data["slice"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_List, v.Kind()) + require.Equal(t, int64(3), v.Length()) + require.Equal(t, `list{ + 0: int{1} + 1: int{2} + 2: int{3} +}`, printer.Sprint(v)) + + v, err = Any(data["array"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_List, v.Kind()) + require.Equal(t, int64(2), v.Length()) + require.Equal(t, `list{ + 0: int{1} + 1: int{2} +}`, printer.Sprint(v)) + + v, err = Any(data["map"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Map, v.Kind()) + require.Equal(t, int64(2), v.Length()) + require.Equal(t, `map{ + string{"foo"}: string{"bar"} + string{"foofoo"}: map{ + string{"barbar"}: string{"foo"} + } +}`, printer.Sprint(v)) + + v, err = Any(data["link"]) + require.NoError(t, err) + require.Equal(t, datamodel.Kind_Link, v.Kind()) + asLink, err := v.AsLink() + require.NoError(t, err) + require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"))) + + v, err = Any(data["func"]) + require.Error(t, err) } func must[T any](t T, err error) T { diff --git a/token/internal/envelope/ipld.go b/token/internal/envelope/ipld.go index 6b74809..6e9533d 100644 --- a/token/internal/envelope/ipld.go +++ b/token/internal/envelope/ipld.go @@ -187,6 +187,8 @@ func FromIPLD[T Tokener](node datamodel.Node) (T, error) { return zero, errors.New("the VarsigHeader key type doesn't match the issuer's key type") } + // TODO: this re-encode the payload! Is there a less wasteful way? + data, err := ipld.Encode(info.sigPayloadNode, dagcbor.Encode) if err != nil { return zero, err