From d353dfe6529c90dfbd26949f91637687747cfb9d Mon Sep 17 00:00:00 2001 From: Steve Moyer Date: Thu, 7 Nov 2024 12:58:53 -0500 Subject: [PATCH] feat(args): create a specialized type to manage invocation Arguments --- pkg/args/args.go | 206 ++++++++++++++++++++++++++++++++++++++++++ pkg/args/args_test.go | 192 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 398 insertions(+) create mode 100644 pkg/args/args.go create mode 100644 pkg/args/args_test.go diff --git a/pkg/args/args.go b/pkg/args/args.go new file mode 100644 index 0000000..5436a37 --- /dev/null +++ b/pkg/args/args.go @@ -0,0 +1,206 @@ +// Package args provides the type that represents the Arguments passed to +// a command within an invocation.Token as well as a convenient Add method +// to incrementally build the underlying map. +package args + +import ( + "fmt" + "reflect" + "sync" + + "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/node/bindnode" + "github.com/ipld/go-ipld-prime/schema" + "github.com/ucan-wg/go-ucan/pkg/policy/literal" +) + +const ( + argsSchema = "type Args { String : Any }" + argsName = "Args" +) + +var ( + once sync.Once + ts *schema.TypeSystem + err error +) + +func argsType() schema.Type { + once.Do(func() { + ts, err = ipld.LoadSchemaBytes([]byte(argsSchema)) + }) + if err != nil { + panic(err) + } + + return ts.TypeByName(argsName) +} + +var ErrUnsupported = fmt.Errorf("failure adding unsupported type to meta") + +// Args are the Command's argumennts when an invocation Token is processed +// by the executor. +// +// This type must be compatible with the IPLD type represented by the IPLD +// schema { String : Any }. +type Args struct { + Keys []string + Values map[string]ipld.Node +} + +// New returns a pointer to an initialized Args value. +func New() *Args { + return &Args{ + Values: map[string]ipld.Node{}, + } +} + +// FromIPLD unwraps an Args instance from an ipld.Node. +func FromIPLD(node ipld.Node) (*Args, error) { + var err error + + defer func() { + err = handlePanic(recover()) + }() + + obj := bindnode.Unwrap(node) + + args, ok := obj.(*Args) + if !ok { + err = fmt.Errorf("failed to convert to Args") + } + + return args, err +} + +// Add inserts a key/value pair in the Args set. +// +// Accepted types for val are: bool, string, int, int8, int16, +// int32, int64, uint, uint8, uint16, uint32, float32, float64, []byte, +// []any, map[string]any, ipld.Node and nil. +func (m *Args) Add(key string, val any) error { + if _, ok := m.Values[key]; ok { + return fmt.Errorf("duplicate key %q", key) + } + + node, err := anyNode(val) + if err != nil { + return err + } + + m.Values[key] = node + m.Keys = append(m.Keys, key) + + return nil +} + +// Include merges the provided arguments into the existing arguments. +// +// If duplicate keys are encountered, the new value is silently dropped +// without causing an error. +func (m *Args) Include(other *Args) { + for _, key := range other.Keys { + if _, ok := m.Values[key]; ok { + // don't overwrite + continue + } + m.Values[key] = other.Values[key] + m.Keys = append(m.Keys, key) + } +} + +// ToIPLD wraps an instance of an Args with an ipld.Node. +func (m *Args) ToIPLD() (ipld.Node, error) { + var err error + + defer func() { + err = handlePanic(recover()) + }() + + return bindnode.Wrap(m, argsType()), err +} + +func anyNode(val any) (ipld.Node, error) { + var err error + + defer func() { + err = handlePanic(recover()) + }() + + if val == nil { + return literal.Null(), nil + } + + if cast, ok := val.(ipld.Node); ok { + return cast, nil + } + + if cast, ok := val.(cid.Cid); ok { + return literal.LinkCid(cast), err + } + + var rv reflect.Value + + rv.Kind() + + if cast, ok := val.(reflect.Value); ok { + rv = cast + } else { + rv = reflect.ValueOf(val) + } + + for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + + switch rv.Kind() { + case reflect.Slice: + if rv.Type().Elem().Kind() == reflect.Uint8 { + return literal.Bytes(val.([]byte)), nil + } + + l := make([]reflect.Value, rv.Len()) + + for i := 0; i < rv.Len(); i++ { + l[i] = rv.Index(i) + } + + return literal.List(l) + case reflect.Map: + if rv.Type().Key().Kind() != reflect.String { + return nil, fmt.Errorf("unsupported map key type: %s", rv.Type().Key().Name()) + } + + m := make(map[string]reflect.Value, rv.Len()) + it := rv.MapRange() + + for it.Next() { + m[it.Key().String()] = it.Value() + } + + return literal.Map(m) + case reflect.String: + return literal.String(rv.String()), nil + case reflect.Bool: + return literal.Bool(rv.Bool()), nil + // reflect.Int64 may exceed the safe 53-bit limit of JavaScript + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return literal.Int(rv.Int()), nil + // reflect.Uint64 can't be safely converted to int64 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return literal.Int(int64(rv.Uint())), nil + case reflect.Float32, reflect.Float64: + return literal.Float(rv.Float()), nil + default: + return nil, fmt.Errorf("unsupported Args type: %s", rv.Type().Name()) + } +} + +func handlePanic(rec any) error { + if err, ok := rec.(error); ok { + return err + } + + return fmt.Errorf("%v", rec) +} diff --git a/pkg/args/args_test.go b/pkg/args/args_test.go new file mode 100644 index 0000000..2cccb82 --- /dev/null +++ b/pkg/args/args_test.go @@ -0,0 +1,192 @@ +package args_test + +import ( + "sync" + "testing" + + "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/codec/dagcbor" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/bindnode" + "github.com/ipld/go-ipld-prime/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/pkg/args" + "github.com/ucan-wg/go-ucan/pkg/policy/literal" +) + +const ( + argsSchema = "type Args { String : Any }" + argsName = "Args" +) + +var ( + once sync.Once + ts *schema.TypeSystem + err error +) + +func argsType() schema.Type { + once.Do(func() { + ts, err = ipld.LoadSchemaBytes([]byte(argsSchema)) + }) + if err != nil { + panic(err) + } + + return ts.TypeByName(argsName) +} + +func argsPrototype() schema.TypedPrototype { + return bindnode.Prototype((*args.Args)(nil), argsType()) +} + +func TestArgs(t *testing.T) { + t.Parallel() + + const ( + intKey = "intKey" + mapKey = "mapKey" + nilKey = "nilKey" + boolKey = "boolKey" + linkKey = "linkKey" + listKey = "listKey" + nodeKey = "nodeKey" + uintKey = "uintKey" + bytesKey = "bytesKey" + floatKey = "floatKey" + stringKey = "stringKey" + ) + + const ( + expIntVal = int64(-42) + expBoolVal = true + expUintVal = uint(42) + expStringVal = "stringVal" + ) + + var ( + expMapVal = map[string]string{"keyOne": "valOne", "keyTwo": "valTwo"} + // expNilVal = (map[string]string)(nil) + expLinkVal = cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm") + expListVal = []string{"elem1", "elem2", "elem3"} + expNodeVal = literal.String("nodeVal") + expBytesVal = []byte{0xde, 0xad, 0xbe, 0xef} + expFloatVal = 42.0 + ) + + argsIn := args.New() + + // WARNING: Do not change the order of these statements as this is the + // order which will be present when decoded from DAG-CBOR ( + // per RFC7049 default canonical ordering?). + for _, a := range []struct { + key string + val any + }{ + {key: intKey, val: expIntVal}, + {key: mapKey, val: expMapVal}, + // {key: nilKey, val: expNilVal}, + {key: boolKey, val: expBoolVal}, + {key: linkKey, val: expLinkVal}, + {key: listKey, val: expListVal}, + {key: nodeKey, val: expNodeVal}, + {key: uintKey, val: expUintVal}, + {key: bytesKey, val: expBytesVal}, + {key: floatKey, val: expFloatVal}, + {key: stringKey, val: expStringVal}, + } { + require.NoError(t, argsIn.Add(a.key, a.val)) + } + + // Round-trip to DAG-CBOR here as ToIPLD/FromIPLD is only a wrapper + argsOut := roundTripThroughDAGCBOR(t, argsIn) + assert.Equal(t, argsIn, argsOut) + + actMapVal := map[string]string{} + mit := argsOut.Values[mapKey].MapIterator() + + es := errorSwallower(t) + + for !mit.Done() { + k, v, err := mit.Next() + require.NoError(t, err) + ks := es(k.AsString()).(string) + vs := es(v.AsString()).(string) + + actMapVal[ks] = vs + } + + actListVal := []string{} + lit := argsOut.Values[listKey].ListIterator() + + for !lit.Done() { + _, v, err := lit.Next() + require.NoError(t, err) + vs := es(v.AsString()).(string) + + actListVal = append(actListVal, vs) + } + + assert.Equal(t, expIntVal, es(argsOut.Values[intKey].AsInt())) + assert.Equal(t, expMapVal, actMapVal) // TODO: special accessor + // TODO: the nil map comes back empty (but the right type) + // assert.Equal(t, expNilVal, actNilVal) + assert.Equal(t, expBoolVal, es(argsOut.Values[boolKey].AsBool())) + assert.Equal(t, expLinkVal.String(), es(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor + assert.Equal(t, expListVal, actListVal) // TODO: special accessor + assert.Equal(t, expNodeVal, argsOut.Values[nodeKey]) + assert.Equal(t, expUintVal, uint(es(argsOut.Values[uintKey].AsInt()).(int64))) + assert.Equal(t, expBytesVal, es(argsOut.Values[bytesKey].AsBytes())) + assert.Equal(t, expFloatVal, es(argsOut.Values[floatKey].AsFloat())) + assert.Equal(t, expStringVal, es(argsOut.Values[stringKey].AsString())) +} + +func TestArgs_Include(t *testing.T) { + t.Parallel() + + argsIn := args.New() + require.NoError(t, argsIn.Add("key1", "val1")) + require.NoError(t, argsIn.Add("key2", "val2")) + + argsOther := args.New() + require.NoError(t, argsOther.Add("key2", "valOther")) // This should not overwrite key2 above + require.NoError(t, argsOther.Add("key3", "val3")) + require.NoError(t, argsOther.Add("key4", "val4")) + + argsIn.Include(argsOther) + + es := errorSwallower(t) + + assert.Len(t, argsIn.Values, 4) + assert.Equal(t, "val1", es(argsIn.Values["key1"].AsString())) + assert.Equal(t, "val2", es(argsIn.Values["key2"].AsString())) + assert.Equal(t, "val3", es(argsIn.Values["key3"].AsString())) + assert.Equal(t, "val4", es(argsIn.Values["key4"].AsString())) +} + +func errorSwallower(t *testing.T) func(any, error) any { + return func(val any, err error) any { + require.NoError(t, err) + + return val + } +} + +func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args { + t.Helper() + + node, err := argsIn.ToIPLD() + require.NoError(t, err) + + data, err := ipld.Encode(node, dagcbor.Encode) + require.NoError(t, err) + node, err = ipld.DecodeUsingPrototype(data, dagcbor.Decode, argsPrototype()) + require.NoError(t, err) + + argsOut, err := args.FromIPLD(node) + require.NoError(t, err) + + return argsOut +}