args: simplify API + code

This commit is contained in:
Michael Muré
2024-11-12 12:14:58 +01:00
parent 633b3d210a
commit 522181b16a
5 changed files with 206 additions and 205 deletions

View File

@@ -5,41 +5,17 @@ package args
import (
"fmt"
"reflect"
"sync"
"sort"
"github.com/ipfs/go-cid"
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/node/bindnode"
"github.com/ipld/go-ipld-prime/schema"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/fluent/qp"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
)
const (
argsSchema = "type Args { String : Any }"
argsName = "Args"
)
var (
once sync.Once
ts *schema.TypeSystem
err error
)
func argsType() schema.Type {
once.Do(func() {
ts, err = ipld.LoadSchemaBytes([]byte(argsSchema))
})
if err != nil {
panic(err)
}
return ts.TypeByName(argsName)
}
var ErrUnsupported = fmt.Errorf("failure adding unsupported type to meta")
// Args are the Command's argumennts when an invocation Token is processed
// Args are the Command's arguments when an invocation Token is processed
// by the executor.
//
// This type must be compatible with the IPLD type represented by the IPLD
@@ -56,41 +32,23 @@ func New() *Args {
}
}
// FromIPLD unwraps an Args instance from an ipld.Node.
func FromIPLD(node ipld.Node) (*Args, error) {
var err error
defer func() {
err = handlePanic(recover())
}()
obj := bindnode.Unwrap(node)
args, ok := obj.(*Args)
if !ok {
err = fmt.Errorf("failed to convert to Args")
}
return args, err
}
// Add inserts a key/value pair in the Args set.
//
// Accepted types for val are: bool, string, int, int8, int16,
// int32, int64, uint, uint8, uint16, uint32, float32, float64, []byte,
// []any, map[string]any, ipld.Node and nil.
func (m *Args) Add(key string, val any) error {
if _, ok := m.Values[key]; ok {
func (a *Args) Add(key string, val any) error {
if _, ok := a.Values[key]; ok {
return fmt.Errorf("duplicate key %q", key)
}
node, err := anyNode(val)
node, err := literal.Any(val)
if err != nil {
return err
}
m.Values[key] = node
m.Keys = append(m.Keys, key)
a.Values[key] = node
a.Keys = append(a.Keys, key)
return nil
}
@@ -99,108 +57,39 @@ func (m *Args) Add(key string, val any) error {
//
// If duplicate keys are encountered, the new value is silently dropped
// without causing an error.
func (m *Args) Include(other *Args) {
func (a *Args) Include(other *Args) {
for _, key := range other.Keys {
if _, ok := m.Values[key]; ok {
if _, ok := a.Values[key]; ok {
// don't overwrite
continue
}
m.Values[key] = other.Values[key]
m.Keys = append(m.Keys, key)
a.Values[key] = other.Values[key]
a.Keys = append(a.Keys, key)
}
}
// ToIPLD wraps an instance of an Args with an ipld.Node.
func (m *Args) ToIPLD() (ipld.Node, error) {
var err error
defer func() {
err = handlePanic(recover())
}()
return bindnode.Wrap(m, argsType()), err
func (a *Args) ToIPLD() (ipld.Node, error) {
sort.Strings(a.Keys)
return qp.BuildMap(basicnode.Prototype.Any, int64(len(a.Keys)), func(ma datamodel.MapAssembler) {
for _, key := range a.Keys {
qp.MapEntry(ma, key, qp.Node(a.Values[key]))
}
})
}
func anyNode(val any) (ipld.Node, error) {
var err error
defer func() {
err = handlePanic(recover())
}()
if val == nil {
return literal.Null(), nil
// Equals tells if two Args hold the same values.
func (a *Args) Equals(other *Args) bool {
if len(a.Keys) != len(other.Keys) {
return false
}
if cast, ok := val.(ipld.Node); ok {
return cast, nil
if len(a.Values) != len(other.Values) {
return false
}
if cast, ok := val.(cid.Cid); ok {
return literal.LinkCid(cast), err
}
var rv reflect.Value
rv.Kind()
if cast, ok := val.(reflect.Value); ok {
rv = cast
} else {
rv = reflect.ValueOf(val)
}
for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface {
rv = rv.Elem()
}
switch rv.Kind() {
case reflect.Slice:
if rv.Type().Elem().Kind() == reflect.Uint8 {
return literal.Bytes(val.([]byte)), nil
for _, key := range a.Keys {
if !ipld.DeepEqual(a.Values[key], other.Values[key]) {
return false
}
l := make([]reflect.Value, rv.Len())
for i := 0; i < rv.Len(); i++ {
l[i] = rv.Index(i)
}
return literal.List(l)
case reflect.Map:
if rv.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("unsupported map key type: %s", rv.Type().Key().Name())
}
m := make(map[string]reflect.Value, rv.Len())
it := rv.MapRange()
for it.Next() {
m[it.Key().String()] = it.Value()
}
return literal.Map(m)
case reflect.String:
return literal.String(rv.String()), nil
case reflect.Bool:
return literal.Bool(rv.Bool()), nil
// reflect.Int64 may exceed the safe 53-bit limit of JavaScript
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return literal.Int(rv.Int()), nil
// reflect.Uint64 can't be safely converted to int64
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return literal.Int(int64(rv.Uint())), nil
case reflect.Float32, reflect.Float64:
return literal.Float(rv.Float()), nil
default:
return nil, fmt.Errorf("unsupported Args type: %s", rv.Type().Name())
}
}
func handlePanic(rec any) error {
if err, ok := rec.(error); ok {
return err
}
return fmt.Errorf("%v", rec)
return true
}

View File

@@ -8,40 +8,14 @@ import (
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/codec/dagcbor"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/node/bindnode"
"github.com/ipld/go-ipld-prime/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/args"
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
)
const (
argsSchema = "type Args { String : Any }"
argsName = "Args"
)
var (
once sync.Once
ts *schema.TypeSystem
err error
)
func argsType() schema.Type {
once.Do(func() {
ts, err = ipld.LoadSchemaBytes([]byte(argsSchema))
})
if err != nil {
panic(err)
}
return ts.TypeByName(argsName)
}
func argsPrototype() schema.TypedPrototype {
return bindnode.Prototype((*args.Args)(nil), argsType())
}
func TestArgs(t *testing.T) {
t.Parallel()
@@ -78,9 +52,6 @@ func TestArgs(t *testing.T) {
argsIn := args.New()
// WARNING: Do not change the order of these statements as this is the
// order which will be present when decoded from DAG-CBOR (
// per RFC7049 default canonical ordering?).
for _, a := range []struct {
key string
val any
@@ -91,8 +62,8 @@ func TestArgs(t *testing.T) {
{key: boolKey, val: expBoolVal},
{key: linkKey, val: expLinkVal},
{key: listKey, val: expListVal},
{key: nodeKey, val: expNodeVal},
{key: uintKey, val: expUintVal},
{key: nodeKey, val: expNodeVal},
{key: bytesKey, val: expBytesVal},
{key: floatKey, val: expFloatVal},
{key: stringKey, val: expStringVal},
@@ -100,21 +71,19 @@ func TestArgs(t *testing.T) {
require.NoError(t, argsIn.Add(a.key, a.val))
}
// Round-trip to DAG-CBOR here as ToIPLD/FromIPLD is only a wrapper
// Round-trip to DAG-CBOR
argsOut := roundTripThroughDAGCBOR(t, argsIn)
assert.Equal(t, argsIn, argsOut)
assert.ElementsMatch(t, argsIn.Keys, argsOut.Keys)
assert.Equal(t, argsIn.Values, argsOut.Values)
actMapVal := map[string]string{}
mit := argsOut.Values[mapKey].MapIterator()
es := errorSwallower(t)
for !mit.Done() {
k, v, err := mit.Next()
require.NoError(t, err)
ks := es(k.AsString()).(string)
vs := es(v.AsString()).(string)
ks := must(k.AsString())
vs := must(v.AsString())
actMapVal[ks] = vs
}
@@ -124,23 +93,23 @@ func TestArgs(t *testing.T) {
for !lit.Done() {
_, v, err := lit.Next()
require.NoError(t, err)
vs := es(v.AsString()).(string)
vs := must(v.AsString())
actListVal = append(actListVal, vs)
}
assert.Equal(t, expIntVal, es(argsOut.Values[intKey].AsInt()))
assert.Equal(t, expIntVal, must(argsOut.Values[intKey].AsInt()))
assert.Equal(t, expMapVal, actMapVal) // TODO: special accessor
// TODO: the nil map comes back empty (but the right type)
// assert.Equal(t, expNilVal, actNilVal)
assert.Equal(t, expBoolVal, es(argsOut.Values[boolKey].AsBool()))
assert.Equal(t, expLinkVal.String(), es(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor
assert.Equal(t, expListVal, actListVal) // TODO: special accessor
assert.Equal(t, expBoolVal, must(argsOut.Values[boolKey].AsBool()))
assert.Equal(t, expLinkVal.String(), must(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor
assert.Equal(t, expListVal, actListVal) // TODO: special accessor
assert.Equal(t, expNodeVal, argsOut.Values[nodeKey])
assert.Equal(t, expUintVal, uint(es(argsOut.Values[uintKey].AsInt()).(int64)))
assert.Equal(t, expBytesVal, es(argsOut.Values[bytesKey].AsBytes()))
assert.Equal(t, expFloatVal, es(argsOut.Values[floatKey].AsFloat()))
assert.Equal(t, expStringVal, es(argsOut.Values[stringKey].AsString()))
assert.Equal(t, expUintVal, uint(must(argsOut.Values[uintKey].AsInt())))
assert.Equal(t, expBytesVal, must(argsOut.Values[bytesKey].AsBytes()))
assert.Equal(t, expFloatVal, must(argsOut.Values[floatKey].AsFloat()))
assert.Equal(t, expStringVal, must(argsOut.Values[stringKey].AsString()))
}
func TestArgs_Include(t *testing.T) {
@@ -157,21 +126,33 @@ func TestArgs_Include(t *testing.T) {
argsIn.Include(argsOther)
es := errorSwallower(t)
assert.Len(t, argsIn.Values, 4)
assert.Equal(t, "val1", es(argsIn.Values["key1"].AsString()))
assert.Equal(t, "val2", es(argsIn.Values["key2"].AsString()))
assert.Equal(t, "val3", es(argsIn.Values["key3"].AsString()))
assert.Equal(t, "val4", es(argsIn.Values["key4"].AsString()))
assert.Equal(t, "val1", must(argsIn.Values["key1"].AsString()))
assert.Equal(t, "val2", must(argsIn.Values["key2"].AsString()))
assert.Equal(t, "val3", must(argsIn.Values["key3"].AsString()))
assert.Equal(t, "val4", must(argsIn.Values["key4"].AsString()))
}
func errorSwallower(t *testing.T) func(any, error) any {
return func(val any, err error) any {
require.NoError(t, err)
const (
argsSchema = "type Args { String : Any }"
argsName = "Args"
)
return val
var (
once sync.Once
ts *schema.TypeSystem
err error
)
func argsType() schema.Type {
once.Do(func() {
ts, err = ipld.LoadSchemaBytes([]byte(argsSchema))
})
if err != nil {
panic(err)
}
return ts.TypeByName(argsName)
}
func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args {
@@ -182,11 +163,17 @@ func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args {
data, err := ipld.Encode(node, dagcbor.Encode)
require.NoError(t, err)
node, err = ipld.DecodeUsingPrototype(data, dagcbor.Decode, argsPrototype())
var argsOut args.Args
_, err = ipld.Unmarshal(data, dagcbor.Decode, &argsOut, argsType())
require.NoError(t, err)
argsOut, err := args.FromIPLD(node)
require.NoError(t, err)
return argsOut
return &argsOut
}
func must[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}

View File

@@ -55,6 +55,23 @@ func List[T any](l []T) (ipld.Node, error) {
})
}
// Any creates an IPLD node from any value
// If possible, use another dedicated function for your type for performance.
func Any(v any) (res ipld.Node, err error) {
builder := basicnode.Prototype__Any{}.NewBuilder()
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("%v", r)
res = nil
}
}()
anyAssemble(v)(builder)
return builder.Build(), nil
}
func anyAssemble(val any) qp.Assemble {
var rt reflect.Type
var rv reflect.Value
@@ -117,6 +134,11 @@ func anyAssemble(val any) qp.Assemble {
return qp.Float(rv.Float())
case reflect.String:
return qp.String(rv.String())
case reflect.Struct:
if rt == reflect.TypeOf(cid.Cid{}) {
c := rv.Interface().(cid.Cid)
return qp.Link(cidlink.Link{Cid: c})
}
default:
}

View File

@@ -3,7 +3,9 @@ package literal
import (
"testing"
"github.com/ipfs/go-cid"
"github.com/ipld/go-ipld-prime/datamodel"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
"github.com/ipld/go-ipld-prime/printer"
"github.com/stretchr/testify/require"
)
@@ -53,6 +55,7 @@ func TestMap(t *testing.T) {
"barbar": "foo",
},
},
"link": cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"),
})
require.NoError(t, err)
@@ -115,6 +118,104 @@ func TestMap(t *testing.T) {
string{"barbar"}: string{"foo"}
}
}`, printer.Sprint(v))
v, err = n.LookupByString("link")
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Link, v.Kind())
asLink, err := v.AsLink()
require.NoError(t, err)
require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm")))
}
func TestAny(t *testing.T) {
data := map[string]any{
"bool": true,
"string": "foobar",
"bytes": []byte{1, 2, 3, 4},
"int": 1234,
"uint": uint(12345),
"float": 1.45,
"slice": []int{1, 2, 3},
"array": [2]int{1, 2},
"map": map[string]any{
"foo": "bar",
"foofoo": map[string]string{
"barbar": "foo",
},
},
"link": cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"),
"func": func() {},
}
v, err := Any(data["bool"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Bool, v.Kind())
require.Equal(t, true, must(v.AsBool()))
v, err = Any(data["string"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_String, v.Kind())
require.Equal(t, "foobar", must(v.AsString()))
v, err = Any(data["bytes"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Bytes, v.Kind())
require.Equal(t, []byte{1, 2, 3, 4}, must(v.AsBytes()))
v, err = Any(data["int"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Int, v.Kind())
require.Equal(t, int64(1234), must(v.AsInt()))
v, err = Any(data["uint"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Int, v.Kind())
require.Equal(t, int64(12345), must(v.AsInt()))
v, err = Any(data["float"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Float, v.Kind())
require.Equal(t, 1.45, must(v.AsFloat()))
v, err = Any(data["slice"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_List, v.Kind())
require.Equal(t, int64(3), v.Length())
require.Equal(t, `list{
0: int{1}
1: int{2}
2: int{3}
}`, printer.Sprint(v))
v, err = Any(data["array"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_List, v.Kind())
require.Equal(t, int64(2), v.Length())
require.Equal(t, `list{
0: int{1}
1: int{2}
}`, printer.Sprint(v))
v, err = Any(data["map"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Map, v.Kind())
require.Equal(t, int64(2), v.Length())
require.Equal(t, `map{
string{"foo"}: string{"bar"}
string{"foofoo"}: map{
string{"barbar"}: string{"foo"}
}
}`, printer.Sprint(v))
v, err = Any(data["link"])
require.NoError(t, err)
require.Equal(t, datamodel.Kind_Link, v.Kind())
asLink, err := v.AsLink()
require.NoError(t, err)
require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm")))
v, err = Any(data["func"])
require.Error(t, err)
}
func must[T any](t T, err error) T {

View File

@@ -187,6 +187,8 @@ func FromIPLD[T Tokener](node datamodel.Node) (T, error) {
return zero, errors.New("the VarsigHeader key type doesn't match the issuer's key type")
}
// TODO: this re-encode the payload! Is there a less wasteful way?
data, err := ipld.Encode(info.sigPayloadNode, dagcbor.Encode)
if err != nil {
return zero, err