diff --git a/pkg/container/reader.go b/pkg/container/reader.go index 61402e4..db1e145 100644 --- a/pkg/container/reader.go +++ b/pkg/container/reader.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "fmt" "io" + "iter" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" @@ -42,6 +43,19 @@ func (ctn Reader) GetDelegation(cid cid.Cid) (*delegation.Token, error) { return nil, fmt.Errorf("not a delegation token") } +// GetAllDelegations returns all the delegation.Token in the container. +func (ctn Reader) GetAllDelegations() iter.Seq2[cid.Cid, *delegation.Token] { + return func(yield func(cid.Cid, *delegation.Token) bool) { + for c, t := range ctn { + if t, ok := t.(*delegation.Token); ok { + if !yield(c, t) { + return + } + } + } + } +} + // GetInvocation returns the first found invocation.Token. // If none are found, ErrNotFound is returned. func (ctn Reader) GetInvocation() (*invocation.Token, error) { diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index fa236da..b3e6aa3 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -2,14 +2,17 @@ package literal import ( + "fmt" + "reflect" + "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" ) -// TODO: remove entirely? - var Bool = basicnode.NewBool var Int = basicnode.NewInt var Float = basicnode.NewFloat @@ -26,3 +29,73 @@ func Null() ipld.Node { nb.AssignNull() return nb.Build() } + +// Map creates an IPLD node from a map[string]any +func Map[T any](m map[string]T) (ipld.Node, error) { + return qp.BuildMap(basicnode.Prototype.Any, int64(len(m)), func(ma datamodel.MapAssembler) { + for k, v := range m { + qp.MapEntry(ma, k, anyAssemble(v)) + } + }) +} + +// List creates an IPLD node from a []any +func List[T any](l []T) (ipld.Node, error) { + return qp.BuildList(basicnode.Prototype.Any, int64(len(l)), func(la datamodel.ListAssembler) { + for _, val := range l { + qp.ListEntry(la, anyAssemble(val)) + } + }) +} + +func anyAssemble(val any) qp.Assemble { + var rt reflect.Type + var rv reflect.Value + + // support for recursive calls, staying in reflection land + if cast, ok := val.(reflect.Value); ok { + rt = cast.Type() + rv = cast + } else { + rt = reflect.TypeOf(val) + rv = reflect.ValueOf(val) + } + + // we need to dereference in some cases, to get the real value type + if rt.Kind() == reflect.Ptr || rt.Kind() == reflect.Interface { + rv = rv.Elem() + rt = rv.Type() + } + + switch rt.Kind() { + case reflect.Array, reflect.Slice: + return qp.List(int64(rv.Len()), func(la datamodel.ListAssembler) { + for i := range rv.Len() { + qp.ListEntry(la, anyAssemble(rv.Index(i))) + } + }) + case reflect.Map: + if rt.Key().Kind() != reflect.String { + break + } + it := rv.MapRange() + return qp.Map(int64(rv.Len()), func(ma datamodel.MapAssembler) { + for it.Next() { + qp.MapEntry(ma, it.Key().String(), anyAssemble(it.Value())) + } + }) + case reflect.Bool: + return qp.Bool(rv.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return qp.Int(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return qp.Int(int64(rv.Uint())) + case reflect.Float32, reflect.Float64: + return qp.Float(rv.Float()) + case reflect.String: + return qp.String(rv.String()) + default: + } + + panic(fmt.Sprintf("unsupported type %T", val)) +} diff --git a/pkg/policy/match.go b/pkg/policy/match.go index 308229d..c3862d5 100644 --- a/pkg/policy/match.go +++ b/pkg/policy/match.go @@ -12,141 +12,240 @@ import ( // Match determines if the IPLD node satisfies the policy. func (p Policy) Match(node datamodel.Node) bool { for _, stmt := range p { - ok := matchStatement(stmt, node) - if !ok { + res, _ := matchStatement(stmt, node) + switch res { + case matchResultNoData, matchResultFalse: return false + case matchResultOptionalNoData, matchResultTrue: + // continue } } return true } -func matchStatement(statement Statement, node ipld.Node) bool { - switch statement.Kind() { +// PartialMatch returns false IIF one non-optional Statement has the corresponding data and doesn't match. +// If the data is missing or the non-optional Statement is matching, true is returned. +// +// This allows performing the policy checking in multiple steps, and find immediately if a Statement already failed. +// A final call to Match is necessary to make sure that the policy is fully matched, with no missing data +// (apart from optional values). +// +// The first Statement failing to match is returned as well. +func (p Policy) PartialMatch(node datamodel.Node) (bool, Statement) { + for _, stmt := range p { + res, leaf := matchStatement(stmt, node) + switch res { + case matchResultFalse: + return false, leaf + case matchResultNoData, matchResultOptionalNoData, matchResultTrue: + // continue + } + } + return true, nil +} + +type matchResult int8 + +const ( + matchResultTrue matchResult = iota // statement has data and resolve to true + matchResultFalse // statement has data and resolve to false + matchResultNoData // statement has no data + matchResultOptionalNoData // statement has no data and is optional +) + +// matchStatement evaluate the policy against the given ipld.Node and returns: +// - matchResultTrue: if the selector matched and the statement evaluated to true. +// - matchResultFalse: if the selector matched and the statement evaluated to false. +// - matchResultNoData: if the selector didn't match the expected data. +// For matchResultTrue and matchResultNoData, the leaf-most (innermost) statement failing to be true is returned, +// as well as the corresponding root-most encompassing statement. +func matchStatement(cur Statement, node ipld.Node) (_ matchResult, leafMost Statement) { + var boolToRes = func(v bool) (matchResult, Statement) { + if v { + return matchResultTrue, nil + } else { + return matchResultFalse, cur + } + } + + switch cur.Kind() { case KindEqual: - if s, ok := statement.(equality); ok { + if s, ok := cur.(equality); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur } - return datamodel.DeepEqual(s.value, res) + if res == nil { // optional selector didn't match + return matchResultOptionalNoData, nil + } + return boolToRes(datamodel.DeepEqual(s.value, res)) } case KindGreaterThan: - if s, ok := statement.(equality); ok { + if s, ok := cur.(equality); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur } - return isOrdered(s.value, res, gt) + if res == nil { // optional selector didn't match + return matchResultOptionalNoData, nil + } + return boolToRes(isOrdered(s.value, res, gt)) } case KindGreaterThanOrEqual: - if s, ok := statement.(equality); ok { + if s, ok := cur.(equality); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur } - return isOrdered(s.value, res, gte) + if res == nil { // optional selector didn't match + return matchResultOptionalNoData, nil + } + return boolToRes(isOrdered(s.value, res, gte)) } case KindLessThan: - if s, ok := statement.(equality); ok { + if s, ok := cur.(equality); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur } - return isOrdered(s.value, res, lt) + if res == nil { // optional selector didn't match + return matchResultOptionalNoData, nil + } + return boolToRes(isOrdered(s.value, res, lt)) } case KindLessThanOrEqual: - if s, ok := statement.(equality); ok { + if s, ok := cur.(equality); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur } - return isOrdered(s.value, res, lte) + if res == nil { // optional selector didn't match + return matchResultOptionalNoData, nil + } + return boolToRes(isOrdered(s.value, res, lte)) } case KindNot: - if s, ok := statement.(negation); ok { - return !matchStatement(s.statement, node) + if s, ok := cur.(negation); ok { + res, leaf := matchStatement(s.statement, node) + switch res { + case matchResultNoData, matchResultOptionalNoData: + return res, leaf + case matchResultTrue: + return matchResultFalse, leaf + case matchResultFalse: + return matchResultTrue, leaf + } } case KindAnd: - if s, ok := statement.(connective); ok { + if s, ok := cur.(connective); ok { for _, cs := range s.statements { - r := matchStatement(cs, node) - if !r { - return false + res, leaf := matchStatement(cs, node) + switch res { + case matchResultNoData, matchResultOptionalNoData: + return res, leaf + case matchResultTrue: + // continue + case matchResultFalse: + return matchResultFalse, leaf } } - return true + return matchResultTrue, nil } case KindOr: - if s, ok := statement.(connective); ok { + if s, ok := cur.(connective); ok { if len(s.statements) == 0 { - return true + return matchResultTrue, nil } for _, cs := range s.statements { - r := matchStatement(cs, node) - if r { - return true + res, leaf := matchStatement(cs, node) + switch res { + case matchResultNoData, matchResultOptionalNoData: + return res, leaf + case matchResultTrue: + return matchResultTrue, leaf + case matchResultFalse: + // continue } } - return false + return matchResultFalse, cur } case KindLike: - if s, ok := statement.(wildcard); ok { + if s, ok := cur.(wildcard); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur + } + if res == nil { // optional selector didn't match + return matchResultOptionalNoData, nil } v, err := res.AsString() if err != nil { - return false // not a string + return matchResultFalse, cur // not a string } - return s.pattern.Match(v) + return boolToRes(s.pattern.Match(v)) } case KindAll: - if s, ok := statement.(quantifier); ok { + if s, ok := cur.(quantifier); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur + } + if res == nil { + return matchResultOptionalNoData, nil } it := res.ListIterator() if it == nil { - return false // not a list + return matchResultFalse, cur // not a list } for !it.Done() { _, v, err := it.Next() if err != nil { - return false + panic("should never happen") } - ok := matchStatement(s.statement, v) - if !ok { - return false + matchRes, leaf := matchStatement(s.statement, v) + switch matchRes { + case matchResultNoData, matchResultOptionalNoData: + return matchRes, leaf + case matchResultTrue: + // continue + case matchResultFalse: + return matchResultFalse, leaf } } - return true + return matchResultTrue, nil } case KindAny: - if s, ok := statement.(quantifier); ok { + if s, ok := cur.(quantifier); ok { res, err := s.selector.Select(node) if err != nil { - return false + return matchResultNoData, cur + } + if res == nil { + return matchResultOptionalNoData, nil } it := res.ListIterator() if it == nil { - return false // not a list + return matchResultFalse, cur // not a list } for !it.Done() { _, v, err := it.Next() if err != nil { - return false + panic("should never happen") } - ok := matchStatement(s.statement, v) - if ok { - return true + matchRes, leaf := matchStatement(s.statement, v) + switch matchRes { + case matchResultNoData, matchResultOptionalNoData: + return matchRes, leaf + case matchResultTrue: + return matchResultTrue, nil + case matchResultFalse: + // continue } } - return false + return matchResultFalse, cur } } - panic(fmt.Errorf("unimplemented statement kind: %s", statement.Kind())) + panic(fmt.Errorf("unimplemented statement kind: %s", cur.Kind())) } func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool { diff --git a/pkg/policy/match_test.go b/pkg/policy/match_test.go index 9e3de4a..7d10d43 100644 --- a/pkg/policy/match_test.go +++ b/pkg/policy/match_test.go @@ -512,3 +512,380 @@ func FuzzMatch(f *testing.F) { policy.Match(dataNode) }) } + +func TestOptionalSelectors(t *testing.T) { + tests := []struct { + name string + policy Policy + data map[string]any + expected bool + }{ + { + name: "missing optional field returns true", + policy: MustConstruct(Equal(".field?", literal.String("value"))), + data: map[string]any{}, + expected: true, + }, + { + name: "present optional field with matching value returns true", + policy: MustConstruct(Equal(".field?", literal.String("value"))), + data: map[string]any{"field": "value"}, + expected: true, + }, + { + name: "present optional field with non-matching value returns false", + policy: MustConstruct(Equal(".field?", literal.String("value"))), + data: map[string]any{"field": "other"}, + expected: false, + }, + { + name: "missing non-optional field returns false", + policy: MustConstruct(Equal(".field", literal.String("value"))), + data: map[string]any{}, + expected: false, + }, + { + name: "nested missing non-optional field returns false", + policy: MustConstruct(Equal(".outer?.inner", literal.String("value"))), + data: map[string]any{"outer": map[string]any{}}, + expected: false, + }, + { + name: "completely missing nested optional path returns true", + policy: MustConstruct(Equal(".outer?.inner?", literal.String("value"))), + data: map[string]any{}, + expected: true, + }, + { + name: "partially present nested optional path with missing end returns true", + policy: MustConstruct(Equal(".outer?.inner?", literal.String("value"))), + data: map[string]any{"outer": map[string]any{}}, + expected: true, + }, + { + name: "optional array index returns true when array is empty", + policy: MustConstruct(Equal(".array[0]?", literal.String("value"))), + data: map[string]any{"array": []any{}}, + expected: true, + }, + { + name: "non-optional array index returns false when array is empty", + policy: MustConstruct(Equal(".array[0]", literal.String("value"))), + data: map[string]any{"array": []any{}}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nb := basicnode.Prototype.Map.NewBuilder() + n, err := literal.Map(tt.data) + require.NoError(t, err) + err = nb.AssignNode(n) + require.NoError(t, err) + + result := tt.policy.Match(nb.Build()) + require.Equal(t, tt.expected, result) + }) + } +} + +// The unique behaviour of PartialMatch is that it should return true for missing non-optional data (unlike Match). +func TestPartialMatch(t *testing.T) { + tests := []struct { + name string + policy Policy + data map[string]any + expectedMatch bool + expectedStmt Statement + }{ + { + name: "returns true for missing non-optional field", + policy: MustConstruct( + Equal(".field", literal.String("value")), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true when present data matches", + policy: MustConstruct( + Equal(".foo", literal.String("correct")), + Equal(".missing", literal.String("whatever")), + ), + data: map[string]any{ + "foo": "correct", + }, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns false with failing statement for present but non-matching value", + policy: MustConstruct( + Equal(".foo", literal.String("value1")), + Equal(".bar", literal.String("value2")), + ), + data: map[string]any{ + "foo": "wrong", + "bar": "value2", + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Equal(".foo", literal.String("value1")), + )[0], + }, + { + name: "continues past missing data until finding actual mismatch", + policy: MustConstruct( + Equal(".missing", literal.String("value")), + Equal(".present", literal.String("wrong")), + ), + data: map[string]any{ + "present": "actual", + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Equal(".present", literal.String("wrong")), + )[0], + }, + + // Optional fields + { + name: "returns false when optional field present but wrong", + policy: MustConstruct( + Equal(".field?", literal.String("value")), + ), + data: map[string]any{ + "field": "wrong", + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Equal(".field?", literal.String("value")), + )[0], + }, + + // Like pattern matching + { + name: "returns true for matching like pattern", + policy: MustConstruct( + Like(".pattern", "test*"), + ), + data: map[string]any{ + "pattern": "testing123", + }, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns false for non-matching like pattern", + policy: MustConstruct( + Like(".pattern", "test*"), + ), + data: map[string]any{ + "pattern": "wrong123", + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Like(".pattern", "test*"), + )[0], + }, + + // Array quantifiers + { + name: "all matches when every element satisfies condition", + policy: MustConstruct( + All(".numbers", Equal(".", literal.Int(1))), + ), + data: map[string]interface{}{ + "numbers": []interface{}{1, 1, 1}, + }, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "all fails when any element doesn't satisfy", + policy: MustConstruct( + All(".numbers", Equal(".", literal.Int(1))), + ), + data: map[string]interface{}{ + "numbers": []interface{}{1, 2, 1}, + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Equal(".", literal.Int(1)), + )[0], + }, + { + name: "any succeeds when one element matches", + policy: MustConstruct( + Any(".numbers", Equal(".", literal.Int(2))), + ), + data: map[string]interface{}{ + "numbers": []interface{}{1, 2, 3}, + }, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "any fails when no elements match", + policy: MustConstruct( + Any(".numbers", Equal(".", literal.Int(4))), + ), + data: map[string]interface{}{ + "numbers": []interface{}{1, 2, 3}, + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Any(".numbers", Equal(".", literal.Int(4))), + )[0], + }, + + // Complex nested case + { + name: "complex nested policy", + policy: MustConstruct( + And( + Equal(".required", literal.String("present")), + Equal(".optional?", literal.String("value")), + Any(".items", + And( + Equal(".name", literal.String("test")), + Like(".id", "ID*"), + ), + ), + ), + ), + data: map[string]any{ + "required": "present", + "items": []any{ + map[string]any{ + "name": "wrong", + "id": "ID123", + }, + map[string]any{ + "name": "test", + "id": "ID456", + }, + }, + }, + expectedMatch: true, + expectedStmt: nil, + }, + + // missing optional values for all the operators + { + name: "returns true for missing optional equal", + policy: MustConstruct( + Equal(".field?", literal.String("value")), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for missing optional like pattern", + policy: MustConstruct( + Like(".pattern?", "test*"), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for missing optional greater than", + policy: MustConstruct( + GreaterThan(".number?", literal.Int(5)), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for missing optional less than", + policy: MustConstruct( + LessThan(".number?", literal.Int(5)), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for missing optional array with all", + policy: MustConstruct( + All(".numbers?", Equal(".", literal.Int(1))), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for missing optional array with any", + policy: MustConstruct( + Any(".numbers?", Equal(".", literal.Int(1))), + ), + data: map[string]any{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for complex nested optional paths", + policy: MustConstruct( + And( + Equal(".required", literal.String("present")), + Any(".optional_array?", + And( + Equal(".name?", literal.String("test")), + Like(".id?", "ID*"), + ), + ), + ), + ), + data: map[string]any{ + "required": "present", + }, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true for partially present nested optional paths", + policy: MustConstruct( + And( + Equal(".required", literal.String("present")), + Any(".items", + And( + Equal(".name", literal.String("test")), + Like(".optional_id?", "ID*"), + ), + ), + ), + ), + data: map[string]any{ + "required": "present", + "items": []any{ + map[string]any{ + "name": "test", + // optional_id is missing + }, + }, + }, + expectedMatch: true, + expectedStmt: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node, err := literal.Map(tt.data) + require.NoError(t, err) + + match, stmt := tt.policy.PartialMatch(node) + require.Equal(t, tt.expectedMatch, match) + if tt.expectedStmt == nil { + require.Nil(t, stmt) + } else { + require.Equal(t, tt.expectedStmt, stmt) + } + }) + } +} diff --git a/pkg/policy/selector/parsing.go b/pkg/policy/selector/parsing.go index a432ec0..507ef77 100644 --- a/pkg/policy/selector/parsing.go +++ b/pkg/policy/selector/parsing.go @@ -9,7 +9,6 @@ import ( ) var ( - identity = Selector{segment{str: ".", identity: true}} indexRegex = regexp.MustCompile(`^-?\d+$`) sliceRegex = regexp.MustCompile(`^((\-?\d+:\-?\d*)|(\-?\d*:\-?\d+))$`) fieldRegex = regexp.MustCompile(`^\.[a-zA-Z_]*?$`) @@ -23,7 +22,7 @@ func Parse(str string) (Selector, error) { return nil, newParseError("selector must start with identity segment '.'", str, 0, string(str[0])) } if str == "." { - return identity, nil + return Selector{segment{str: ".", identity: true}}, nil } if str == ".?" { return Selector{segment{str: ".?", identity: true, optional: true}}, nil