From 19721027e4ba3fb5cda144d7c3694eda7ab0186e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Mur=C3=A9?= Date: Mon, 4 Nov 2024 18:27:38 +0100 Subject: [PATCH] literal: rewrite Map() to cover more types --- pkg/policy/literal/literal.go | 122 ++++++++++++++++------------------ 1 file changed, 59 insertions(+), 63 deletions(-) diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index 872f6a1..64baa1d 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -3,9 +3,12 @@ package literal import ( "fmt" + "reflect" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" ) @@ -29,68 +32,61 @@ func Null() ipld.Node { // Map creates an IPLD node from a map[string]any func Map(m map[string]any) (ipld.Node, error) { - nb := basicnode.Prototype.Map.NewBuilder() - ma, err := nb.BeginMap(int64(len(m))) - if err != nil { - return nil, err - } - - for k, v := range m { - if err := ma.AssembleKey().AssignString(k); err != nil { - return nil, err + return qp.BuildMap(basicnode.Prototype.Any, int64(len(m)), func(ma datamodel.MapAssembler) { + for k, v := range m { + qp.MapEntry(ma, k, anyAssemble(v)) } - - switch x := v.(type) { - case string: - if err := ma.AssembleValue().AssignString(x); err != nil { - return nil, err - } - case []any: - lb := basicnode.Prototype.List.NewBuilder() - la, err := lb.BeginList(int64(len(x))) - if err != nil { - return nil, err - } - for _, elem := range x { - switch e := elem.(type) { - case string: - if err := la.AssembleValue().AssignString(e); err != nil { - return nil, err - } - case map[string]any: - nestedNode, err := Map(e) - if err != nil { - return nil, err - } - if err := la.AssembleValue().AssignNode(nestedNode); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("unsupported array element type: %T", elem) - } - } - if err := la.Finish(); err != nil { - return nil, err - } - if err := ma.AssembleValue().AssignNode(lb.Build()); err != nil { - return nil, err - } - case map[string]any: - nestedNode, err := Map(x) // recursive call for nested maps - if err != nil { - return nil, err - } - if err := ma.AssembleValue().AssignNode(nestedNode); err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("unsupported value type: %T", v) - } - } - - if err := ma.Finish(); err != nil { - return nil, err - } - - return nb.Build(), nil + }) +} + +func anyAssemble(val any) qp.Assemble { + var rt reflect.Type + var rv reflect.Value + + // support for recursive calls, staying in reflection land + if cast, ok := val.(reflect.Value); ok { + rt = cast.Type() + rv = cast + } else { + rt = reflect.TypeOf(val) + rv = reflect.ValueOf(val) + } + + // we need to dereference in some cases, to get the real value type + if rt.Kind() == reflect.Ptr || rt.Kind() == reflect.Interface { + rv = rv.Elem() + rt = rv.Type() + } + + switch rt.Kind() { + case reflect.Array, reflect.Slice: + return qp.List(int64(rv.Len()), func(la datamodel.ListAssembler) { + for i := range rv.Len() { + qp.ListEntry(la, anyAssemble(rv.Index(i))) + } + }) + case reflect.Map: + if rt.Key().Kind() != reflect.String { + break + } + it := rv.MapRange() + return qp.Map(int64(rv.Len()), func(ma datamodel.MapAssembler) { + for it.Next() { + qp.MapEntry(ma, it.Key().String(), anyAssemble(it.Value())) + } + }) + case reflect.Bool: + return qp.Bool(rv.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return qp.Int(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return qp.Int(int64(rv.Uint())) + case reflect.Float32, reflect.Float64: + return qp.Float(rv.Float()) + case reflect.String: + return qp.String(rv.String()) + default: + } + + panic(fmt.Sprintf("unsupported type %T", val)) }