diff --git a/capability/policy/match.go b/capability/policy/match.go index d6aaa7a..6da07ca 100644 --- a/capability/policy/match.go +++ b/capability/policy/match.go @@ -11,11 +11,16 @@ import ( "github.com/ucan-wg/go-ucan/capability/policy/selector" ) -// Filter extracts a subset of the policy related to the specified selector. func (p Policy) Filter(sel selector.Selector) Policy { + return p.FilterWithMatcher(sel, selector.SegmentEquals) +} + +// FilterWithMatcher extracts a subset of the policy related to the specified selector, +// by matching each segment using the given selector.SegmentMatcher. +func (p Policy) FilterWithMatcher(sel selector.Selector, matcher selector.SegmentMatcher) Policy { var filtered Policy for _, stmt := range p { - if stmt.Selector().Matches(sel) { + if stmt.Selector().Matches(sel, matcher) { filtered = append(filtered, stmt) } } diff --git a/capability/policy/match_test.go b/capability/policy/match_test.go index 549e211..7ea0bcf 100644 --- a/capability/policy/match_test.go +++ b/capability/policy/match_test.go @@ -563,6 +563,21 @@ func TestPolicyFilter(t *testing.T) { sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")} filtered = p.Filter(sel3) assert.Len(t, filtered, 0) + + stmt1 = Equal( + selector.Selector{selector.NewFieldSegment(".http.host")}, + basicnode.NewString("mainnet.infura.io"), + ) + stmt2, err := Like( + selector.Selector{selector.NewFieldSegment(".jsonrpc.method")}, + "eth_*", + ) + require.NoError(t, err) + + p = Policy{stmt1, stmt2} + filtered = p.FilterWithMatcher(selector.Selector{selector.NewFieldSegment(".http")}, selector.SegmentStartsWith) + assert.Len(t, filtered, 1) + assert.Equal(t, stmt1, filtered[0]) } func FuzzPolicyFilter(f *testing.F) { diff --git a/capability/policy/selector/selector.go b/capability/policy/selector/selector.go index ed2a0f8..9437d59 100644 --- a/capability/policy/selector/selector.go +++ b/capability/policy/selector/selector.go @@ -24,13 +24,13 @@ func (s Selector) String() string { } // Matches checks if the selector matches another selector. -func (s Selector) Matches(other Selector) bool { +func (s Selector) Matches(other Selector, matcher SegmentMatcher) bool { if len(s) != len(other) { return false } for i, seg := range s { - if seg.str != other[i].str { + if !matcher(seg, other[i]) { return false } } @@ -88,6 +88,16 @@ func NewIteratorSegment() segment { } } +type SegmentMatcher func(s1, s2 segment) bool + +var SegmentEquals SegmentMatcher = func(s1, s2 segment) bool { + return s1.str == s2.str +} + +var SegmentStartsWith SegmentMatcher = func(s1, s2 segment) bool { + return strings.HasPrefix(s1.str, s2.str) +} + // String returns the segment's string representation. func (s segment) String() string { return s.str