diff --git a/match.go b/match.go index afeb7f6..5cf07a5 100644 --- a/match.go +++ b/match.go @@ -15,7 +15,7 @@ func Match(policy Policy, node ipld.Node) bool { for _, stmt := range policy { ok := matchStatement(stmt, node) if !ok { - return ok + return false } } return true @@ -103,7 +103,33 @@ func matchStatement(statement Statement, node ipld.Node) bool { return s.Value().Match(v) } case Kind_All: + if s, ok := statement.(QuantifierStatement); ok { + _, many, err := selector.Select(s.Selector(), node) + if err != nil || many == nil { + return false + } + for _, n := range many { + ok := Match(s.Value(), n) + if !ok { + return false + } + } + return true + } case Kind_Any: + if s, ok := statement.(QuantifierStatement); ok { + _, many, err := selector.Select(s.Selector(), node) + if err != nil || many == nil { + return false + } + for _, n := range many { + ok := Match(s.Value(), n) + if ok { + return true + } + } + return false + } } panic(fmt.Errorf("unimplemented statement kind: %s", statement.Kind())) } diff --git a/match_test.go b/match_test.go index 59d9141..10049a9 100644 --- a/match_test.go +++ b/match_test.go @@ -6,6 +6,7 @@ import ( "github.com/gobwas/glob" "github.com/ipfs/go-cid" + "github.com/ipld/go-ipld-prime" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" "github.com/storacha-network/go-ucanto/core/policy/literal" @@ -331,4 +332,77 @@ func TestMatch(t *testing.T) { }(s) } }) + + buildValueNode := func(v int64) ipld.Node { + np := basicnode.Prototype.Map + nb := np.NewBuilder() + ma, _ := nb.BeginMap(1) + ma.AssembleKey().AssignString("value") + ma.AssembleValue().AssignInt(v) + ma.Finish() + return nb.Build() + } + + t.Run("quantification all", func(t *testing.T) { + np := basicnode.Prototype.List + nb := np.NewBuilder() + la, _ := nb.BeginList(5) + la.AssembleValue().AssignNode(buildValueNode(5)) + la.AssembleValue().AssignNode(buildValueNode(10)) + la.AssembleValue().AssignNode(buildValueNode(20)) + la.AssembleValue().AssignNode(buildValueNode(50)) + la.AssembleValue().AssignNode(buildValueNode(100)) + la.Finish() + nd := nb.Build() + + pol := Policy{ + All( + selector.MustParse(".[]"), + GreaterThan(selector.MustParse(".value"), literal.Int(2)), + ), + } + ok := Match(pol, nd) + require.True(t, ok) + + pol = Policy{ + All( + selector.MustParse(".[]"), + GreaterThan(selector.MustParse(".value"), literal.Int(20)), + ), + } + ok = Match(pol, nd) + require.False(t, ok) + }) + + t.Run("quantification any", func(t *testing.T) { + np := basicnode.Prototype.List + nb := np.NewBuilder() + la, _ := nb.BeginList(5) + la.AssembleValue().AssignNode(buildValueNode(5)) + la.AssembleValue().AssignNode(buildValueNode(10)) + la.AssembleValue().AssignNode(buildValueNode(20)) + la.AssembleValue().AssignNode(buildValueNode(50)) + la.AssembleValue().AssignNode(buildValueNode(100)) + la.Finish() + nd := nb.Build() + + pol := Policy{ + Any( + selector.MustParse(".[]"), + GreaterThan(selector.MustParse(".value"), literal.Int(10)), + LessThan(selector.MustParse(".value"), literal.Int(50)), + ), + } + ok := Match(pol, nd) + require.True(t, ok) + + pol = Policy{ + Any( + selector.MustParse(".[]"), + GreaterThan(selector.MustParse(".value"), literal.Int(100)), + ), + } + ok = Match(pol, nd) + require.False(t, ok) + }) } diff --git a/policy.go b/policy.go index b03bf5a..49f1f48 100644 --- a/policy.go +++ b/policy.go @@ -196,10 +196,10 @@ func (n quantifier) Value() Policy { return n.policy } -func All(selector selector.Selector, policy Policy) QuantifierStatement { +func All(selector selector.Selector, policy ...Statement) QuantifierStatement { return quantifier{Kind_All, selector, policy} } -func Any(selector selector.Selector, policy Policy) QuantifierStatement { +func Any(selector selector.Selector, policy ...Statement) QuantifierStatement { return quantifier{Kind_Any, selector, policy} }