args: simplify API + code

This commit is contained in:
Michael Muré
2024-11-12 12:14:58 +01:00
parent 633b3d210a
commit 522181b16a
5 changed files with 206 additions and 205 deletions

View File

@@ -8,40 +8,14 @@ import (
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/codec/dagcbor"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/node/bindnode"
"github.com/ipld/go-ipld-prime/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/args"
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
)
const (
argsSchema = "type Args { String : Any }"
argsName = "Args"
)
var (
once sync.Once
ts *schema.TypeSystem
err error
)
func argsType() schema.Type {
once.Do(func() {
ts, err = ipld.LoadSchemaBytes([]byte(argsSchema))
})
if err != nil {
panic(err)
}
return ts.TypeByName(argsName)
}
func argsPrototype() schema.TypedPrototype {
return bindnode.Prototype((*args.Args)(nil), argsType())
}
func TestArgs(t *testing.T) {
t.Parallel()
@@ -78,9 +52,6 @@ func TestArgs(t *testing.T) {
argsIn := args.New()
// WARNING: Do not change the order of these statements as this is the
// order which will be present when decoded from DAG-CBOR (
// per RFC7049 default canonical ordering?).
for _, a := range []struct {
key string
val any
@@ -91,8 +62,8 @@ func TestArgs(t *testing.T) {
{key: boolKey, val: expBoolVal},
{key: linkKey, val: expLinkVal},
{key: listKey, val: expListVal},
{key: nodeKey, val: expNodeVal},
{key: uintKey, val: expUintVal},
{key: nodeKey, val: expNodeVal},
{key: bytesKey, val: expBytesVal},
{key: floatKey, val: expFloatVal},
{key: stringKey, val: expStringVal},
@@ -100,21 +71,19 @@ func TestArgs(t *testing.T) {
require.NoError(t, argsIn.Add(a.key, a.val))
}
// Round-trip to DAG-CBOR here as ToIPLD/FromIPLD is only a wrapper
// Round-trip to DAG-CBOR
argsOut := roundTripThroughDAGCBOR(t, argsIn)
assert.Equal(t, argsIn, argsOut)
assert.ElementsMatch(t, argsIn.Keys, argsOut.Keys)
assert.Equal(t, argsIn.Values, argsOut.Values)
actMapVal := map[string]string{}
mit := argsOut.Values[mapKey].MapIterator()
es := errorSwallower(t)
for !mit.Done() {
k, v, err := mit.Next()
require.NoError(t, err)
ks := es(k.AsString()).(string)
vs := es(v.AsString()).(string)
ks := must(k.AsString())
vs := must(v.AsString())
actMapVal[ks] = vs
}
@@ -124,23 +93,23 @@ func TestArgs(t *testing.T) {
for !lit.Done() {
_, v, err := lit.Next()
require.NoError(t, err)
vs := es(v.AsString()).(string)
vs := must(v.AsString())
actListVal = append(actListVal, vs)
}
assert.Equal(t, expIntVal, es(argsOut.Values[intKey].AsInt()))
assert.Equal(t, expIntVal, must(argsOut.Values[intKey].AsInt()))
assert.Equal(t, expMapVal, actMapVal) // TODO: special accessor
// TODO: the nil map comes back empty (but the right type)
// assert.Equal(t, expNilVal, actNilVal)
assert.Equal(t, expBoolVal, es(argsOut.Values[boolKey].AsBool()))
assert.Equal(t, expLinkVal.String(), es(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor
assert.Equal(t, expListVal, actListVal) // TODO: special accessor
assert.Equal(t, expBoolVal, must(argsOut.Values[boolKey].AsBool()))
assert.Equal(t, expLinkVal.String(), must(argsOut.Values[linkKey].AsLink()).(datamodel.Link).String()) // TODO: special accessor
assert.Equal(t, expListVal, actListVal) // TODO: special accessor
assert.Equal(t, expNodeVal, argsOut.Values[nodeKey])
assert.Equal(t, expUintVal, uint(es(argsOut.Values[uintKey].AsInt()).(int64)))
assert.Equal(t, expBytesVal, es(argsOut.Values[bytesKey].AsBytes()))
assert.Equal(t, expFloatVal, es(argsOut.Values[floatKey].AsFloat()))
assert.Equal(t, expStringVal, es(argsOut.Values[stringKey].AsString()))
assert.Equal(t, expUintVal, uint(must(argsOut.Values[uintKey].AsInt())))
assert.Equal(t, expBytesVal, must(argsOut.Values[bytesKey].AsBytes()))
assert.Equal(t, expFloatVal, must(argsOut.Values[floatKey].AsFloat()))
assert.Equal(t, expStringVal, must(argsOut.Values[stringKey].AsString()))
}
func TestArgs_Include(t *testing.T) {
@@ -157,21 +126,33 @@ func TestArgs_Include(t *testing.T) {
argsIn.Include(argsOther)
es := errorSwallower(t)
assert.Len(t, argsIn.Values, 4)
assert.Equal(t, "val1", es(argsIn.Values["key1"].AsString()))
assert.Equal(t, "val2", es(argsIn.Values["key2"].AsString()))
assert.Equal(t, "val3", es(argsIn.Values["key3"].AsString()))
assert.Equal(t, "val4", es(argsIn.Values["key4"].AsString()))
assert.Equal(t, "val1", must(argsIn.Values["key1"].AsString()))
assert.Equal(t, "val2", must(argsIn.Values["key2"].AsString()))
assert.Equal(t, "val3", must(argsIn.Values["key3"].AsString()))
assert.Equal(t, "val4", must(argsIn.Values["key4"].AsString()))
}
func errorSwallower(t *testing.T) func(any, error) any {
return func(val any, err error) any {
require.NoError(t, err)
const (
argsSchema = "type Args { String : Any }"
argsName = "Args"
)
return val
var (
once sync.Once
ts *schema.TypeSystem
err error
)
func argsType() schema.Type {
once.Do(func() {
ts, err = ipld.LoadSchemaBytes([]byte(argsSchema))
})
if err != nil {
panic(err)
}
return ts.TypeByName(argsName)
}
func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args {
@@ -182,11 +163,17 @@ func roundTripThroughDAGCBOR(t *testing.T, argsIn *args.Args) *args.Args {
data, err := ipld.Encode(node, dagcbor.Encode)
require.NoError(t, err)
node, err = ipld.DecodeUsingPrototype(data, dagcbor.Decode, argsPrototype())
var argsOut args.Args
_, err = ipld.Unmarshal(data, dagcbor.Decode, &argsOut, argsType())
require.NoError(t, err)
argsOut, err := args.FromIPLD(node)
require.NoError(t, err)
return argsOut
return &argsOut
}
func must[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}