versatile segments matchers

This commit is contained in:
Fabio Bozzo
2024-09-24 19:36:01 +02:00
parent 18820f5e9d
commit d66b8e40ec
3 changed files with 34 additions and 4 deletions

View File

@@ -11,11 +11,16 @@ import (
"github.com/ucan-wg/go-ucan/capability/policy/selector" "github.com/ucan-wg/go-ucan/capability/policy/selector"
) )
// Filter extracts a subset of the policy related to the specified selector.
func (p Policy) Filter(sel selector.Selector) Policy { func (p Policy) Filter(sel selector.Selector) Policy {
return p.FilterWithMatcher(sel, selector.SegmentEquals)
}
// FilterWithMatcher extracts a subset of the policy related to the specified selector,
// by matching each segment using the given selector.SegmentMatcher.
func (p Policy) FilterWithMatcher(sel selector.Selector, matcher selector.SegmentMatcher) Policy {
var filtered Policy var filtered Policy
for _, stmt := range p { for _, stmt := range p {
if stmt.Selector().Matches(sel) { if stmt.Selector().Matches(sel, matcher) {
filtered = append(filtered, stmt) filtered = append(filtered, stmt)
} }
} }

View File

@@ -563,6 +563,21 @@ func TestPolicyFilter(t *testing.T) {
sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")} sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")}
filtered = p.Filter(sel3) filtered = p.Filter(sel3)
assert.Len(t, filtered, 0) assert.Len(t, filtered, 0)
stmt1 = Equal(
selector.Selector{selector.NewFieldSegment(".http.host")},
basicnode.NewString("mainnet.infura.io"),
)
stmt2, err := Like(
selector.Selector{selector.NewFieldSegment(".jsonrpc.method")},
"eth_*",
)
require.NoError(t, err)
p = Policy{stmt1, stmt2}
filtered = p.FilterWithMatcher(selector.Selector{selector.NewFieldSegment(".http")}, selector.SegmentStartsWith)
assert.Len(t, filtered, 1)
assert.Equal(t, stmt1, filtered[0])
} }
func FuzzPolicyFilter(f *testing.F) { func FuzzPolicyFilter(f *testing.F) {

View File

@@ -24,13 +24,13 @@ func (s Selector) String() string {
} }
// Matches checks if the selector matches another selector. // Matches checks if the selector matches another selector.
func (s Selector) Matches(other Selector) bool { func (s Selector) Matches(other Selector, matcher SegmentMatcher) bool {
if len(s) != len(other) { if len(s) != len(other) {
return false return false
} }
for i, seg := range s { for i, seg := range s {
if seg.str != other[i].str { if !matcher(seg, other[i]) {
return false return false
} }
} }
@@ -88,6 +88,16 @@ func NewIteratorSegment() segment {
} }
} }
type SegmentMatcher func(s1, s2 segment) bool
var SegmentEquals SegmentMatcher = func(s1, s2 segment) bool {
return s1.str == s2.str
}
var SegmentStartsWith SegmentMatcher = func(s1, s2 segment) bool {
return strings.HasPrefix(s1.str, s2.str)
}
// String returns the segment's string representation. // String returns the segment's string representation.
func (s segment) String() string { func (s segment) String() string {
return s.str return s.str