4 Commits

Author SHA1 Message Date
Fabio Bozzo
bfd5f00618 Merge branch 'v1' into v1-policy-subset-selection 2024-09-24 19:41:39 +02:00
Fabio Bozzo
d66b8e40ec versatile segments matchers 2024-09-24 19:36:01 +02:00
Fabio Bozzo
18820f5e9d refactor policy.Match as receiver 2024-09-24 11:36:31 +02:00
Fabio Bozzo
ddaa67ed7d feat(policy): filter statements subset by selector 2024-09-24 11:33:06 +02:00
4 changed files with 247 additions and 45 deletions

View File

@@ -11,14 +11,32 @@ import (
"github.com/ucan-wg/go-ucan/pkg/policy/selector"
)
func (p Policy) Filter(sel selector.Selector) Policy {
return p.FilterWithMatcher(sel, selector.SegmentEquals)
}
// FilterWithMatcher extracts a subset of the policy related to the specified selector,
// by matching each segment using the given selector.SegmentMatcher.
func (p Policy) FilterWithMatcher(sel selector.Selector, matcher selector.SegmentMatcher) Policy {
var filtered Policy
for _, stmt := range p {
if stmt.Selector().Matches(sel, matcher) {
filtered = append(filtered, stmt)
}
}
return filtered
}
// Match determines if the IPLD node matches the policy document.
func Match(policy Policy, node ipld.Node) bool {
for _, stmt := range policy {
func (p Policy) Match(node datamodel.Node) bool {
for _, stmt := range p {
ok := matchStatement(stmt, node)
if !ok {
return false
}
}
return true
}

View File

@@ -1,7 +1,9 @@
package policy
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"github.com/ipfs/go-cid"
@@ -9,6 +11,7 @@ import (
"github.com/ipld/go-ipld-prime/codec/dagjson"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
@@ -24,15 +27,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.String("test"))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("test2"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Int(138))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -43,15 +46,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.Int(138))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Int(1138))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("138"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -62,15 +65,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.Float(1.138))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Float(11.38))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("138"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -84,15 +87,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.Link(l0))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Link(l1))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -106,19 +109,19 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse(".foo"), literal.String("bar"))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse(".[\"foo\"]"), literal.String("bar"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse(".foo"), literal.String("baz"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse(".foobar"), literal.String("bar"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -131,11 +134,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse(".[0]"), literal.String("foo"))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse(".[1]"), literal.String("foo"))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
})
@@ -148,7 +151,7 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThan(selector.MustParse("."), literal.Int(1))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
})
@@ -159,11 +162,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(1))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(138))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.True(t, ok)
})
@@ -174,7 +177,7 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThan(selector.MustParse("."), literal.Float(1))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
})
@@ -185,11 +188,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1.38))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.True(t, ok)
})
@@ -200,7 +203,7 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{LessThan(selector.MustParse("."), literal.Int(1138))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
})
@@ -211,11 +214,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(1138))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(138))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.True(t, ok)
})
})
@@ -227,11 +230,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Not(Equal(selector.MustParse("."), literal.Bool(true)))}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{Not(Equal(selector.MustParse("."), literal.Bool(false)))}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -247,7 +250,7 @@ func TestMatch(t *testing.T) {
LessThan(selector.MustParse("."), literal.Int(1138)),
),
}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{
@@ -256,11 +259,11 @@ func TestMatch(t *testing.T) {
Equal(selector.MustParse("."), literal.Int(1138)),
),
}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{And()}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.True(t, ok)
})
@@ -276,7 +279,7 @@ func TestMatch(t *testing.T) {
LessThan(selector.MustParse("."), literal.Int(1138)),
),
}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{
@@ -285,11 +288,11 @@ func TestMatch(t *testing.T) {
Equal(selector.MustParse("."), literal.Int(1138)),
),
}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
pol = Policy{Or()}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.True(t, ok)
})
@@ -313,7 +316,7 @@ func TestMatch(t *testing.T) {
require.NoError(t, err)
pol := Policy{statement}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
})
}(s)
@@ -337,7 +340,7 @@ func TestMatch(t *testing.T) {
require.NoError(t, err)
pol := Policy{statement}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.False(t, ok)
})
}(s)
@@ -373,7 +376,7 @@ func TestMatch(t *testing.T) {
GreaterThan(selector.MustParse(".value"), literal.Int(2)),
),
}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{
@@ -382,7 +385,7 @@ func TestMatch(t *testing.T) {
GreaterThan(selector.MustParse(".value"), literal.Int(20)),
),
}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
@@ -404,7 +407,7 @@ func TestMatch(t *testing.T) {
GreaterThan(selector.MustParse(".value"), literal.Int(60)),
),
}
ok := Match(pol, nd)
ok := pol.Match(nd)
require.True(t, ok)
pol = Policy{
@@ -413,7 +416,7 @@ func TestMatch(t *testing.T) {
GreaterThan(selector.MustParse(".value"), literal.Int(100)),
),
}
ok = Match(pol, nd)
ok = pol.Match(nd)
require.False(t, ok)
})
})
@@ -432,7 +435,7 @@ func TestPolicyExamples(t *testing.T) {
pol, err := FromDagJson(policy)
require.NoError(t, err)
return Match(pol, data)
return pol.Match(data)
}
t.Run("And", func(t *testing.T) {
@@ -536,6 +539,104 @@ func FuzzMatch(f *testing.F) {
t.Skip()
}
Match(policy, dataNode)
policy.Match(dataNode)
})
}
func TestPolicyFilter(t *testing.T) {
sel1 := selector.Selector{selector.NewFieldSegment("http")}
sel2 := selector.Selector{selector.NewFieldSegment("jsonrpc")}
stmt1 := Equal(sel1, basicnode.NewString("value1"))
stmt2 := Equal(sel2, basicnode.NewString("value2"))
p := Policy{stmt1, stmt2}
filtered := p.Filter(sel1)
assert.Len(t, filtered, 1)
assert.Equal(t, stmt1, filtered[0])
filtered = p.Filter(sel2)
assert.Len(t, filtered, 1)
assert.Equal(t, stmt2, filtered[0])
sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")}
filtered = p.Filter(sel3)
assert.Len(t, filtered, 0)
stmt1 = Equal(
selector.Selector{selector.NewFieldSegment(".http.host")},
basicnode.NewString("mainnet.infura.io"),
)
stmt2, err := Like(
selector.Selector{selector.NewFieldSegment(".jsonrpc.method")},
"eth_*",
)
require.NoError(t, err)
p = Policy{stmt1, stmt2}
filtered = p.FilterWithMatcher(selector.Selector{selector.NewFieldSegment(".http")}, selector.SegmentStartsWith)
assert.Len(t, filtered, 1)
assert.Equal(t, stmt1, filtered[0])
}
func FuzzPolicyFilter(f *testing.F) {
f.Add([]byte(`{"selector": [{"field": "http"}], "value": "value1"}`))
f.Add([]byte(`{"selector": [{"field": "jsonrpc"}], "value": "value2"}`))
f.Fuzz(func(t *testing.T, data []byte) {
var input struct {
Selector []struct {
Field string `json:"field"` // because selector.segment is not public
} `json:"selector"`
Value string `json:"value"`
}
if err := json.Unmarshal(data, &input); err != nil {
t.Skip()
}
var sel selector.Selector
for _, seg := range input.Selector {
sel = append(sel, selector.NewFieldSegment(seg.Field))
}
stmt := Equal(sel, basicnode.NewString(input.Value))
// create a policy and filter it based on the fuzzy input selector
p := Policy{stmt}
filtered := p.Filter(sel)
// verify that the filtered policy contains the statement
if len(filtered) != 1 || !reflect.DeepEqual(filtered[0], stmt) {
t.Errorf("filtered policy does not contain the expected statement")
}
})
}
func BenchmarkPolicyFilter(b *testing.B) {
sel1 := selector.Selector{selector.NewFieldSegment("http")}
sel2 := selector.Selector{selector.NewFieldSegment("jsonrpc")}
stmt1 := Equal(sel1, basicnode.NewString("value1"))
stmt2 := Equal(sel2, basicnode.NewString("value2"))
p := Policy{stmt1, stmt2}
b.Run("Filter by sel1", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.Filter(sel1)
}
})
b.Run("Filter by sel2", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.Filter(sel2)
}
})
sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")}
b.Run("Filter by sel3", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.Filter(sel3)
}
})
}

View File

@@ -26,6 +26,7 @@ type Policy []Statement
type Statement interface {
Kind() string
Selector() selector.Selector
}
type equality struct {
@@ -38,6 +39,10 @@ func (e equality) Kind() string {
return e.kind
}
func (e equality) Selector() selector.Selector {
return e.selector
}
func Equal(selector selector.Selector, value ipld.Node) Statement {
return equality{kind: KindEqual, selector: selector, value: value}
}
@@ -66,6 +71,10 @@ func (n negation) Kind() string {
return KindNot
}
func (n negation) Selector() selector.Selector {
return n.statement.Selector()
}
func Not(stmt Statement) Statement {
return negation{statement: stmt}
}
@@ -79,6 +88,15 @@ func (c connective) Kind() string {
return c.kind
}
func (c connective) Selector() selector.Selector {
// assuming the first statement's selector is representative
if len(c.statements) > 0 {
return c.statements[0].Selector()
}
return selector.Selector{}
}
func And(stmts ...Statement) Statement {
return connective{kind: KindAnd, statements: stmts}
}
@@ -96,6 +114,10 @@ func (n wildcard) Kind() string {
return KindLike
}
func (n wildcard) Selector() selector.Selector {
return n.selector
}
func Like(selector selector.Selector, pattern string) (Statement, error) {
g, err := parseGlob(pattern)
if err != nil {
@@ -115,6 +137,10 @@ func (n quantifier) Kind() string {
return n.kind
}
func (n quantifier) Selector() selector.Selector {
return n.selector
}
func All(selector selector.Selector, statement Statement) Statement {
return quantifier{kind: KindAll, selector: selector, statement: statement}
}

View File

@@ -23,6 +23,21 @@ func (s Selector) String() string {
return res.String()
}
// Matches checks if the selector matches another selector.
func (s Selector) Matches(other Selector, matcher SegmentMatcher) bool {
if len(s) != len(other) {
return false
}
for i, seg := range s {
if !matcher(seg, other[i]) {
return false
}
}
return true
}
var Identity = segment{".", true, false, false, nil, "", 0}
var (
@@ -41,6 +56,48 @@ type segment struct {
index int
}
// NewFieldSegment creates a new segment for a field.
func NewFieldSegment(field string) segment {
return segment{
str: fmt.Sprintf(".%s", field),
field: field,
}
}
// NewIndexSegment creates a new segment for an index.
func NewIndexSegment(index int) segment {
return segment{
str: fmt.Sprintf("[%d]", index),
index: index,
}
}
// NewSliceSegment creates a new segment for a slice.
func NewSliceSegment(slice []int) segment {
return segment{
str: fmt.Sprintf("[%d:%d]", slice[0], slice[1]),
slice: slice,
}
}
// NewIteratorSegment creates a new segment for an iterator.
func NewIteratorSegment() segment {
return segment{
str: "*",
iterator: true,
}
}
type SegmentMatcher func(s1, s2 segment) bool
var SegmentEquals SegmentMatcher = func(s1, s2 segment) bool {
return s1.str == s2.str
}
var SegmentStartsWith SegmentMatcher = func(s1, s2 segment) bool {
return strings.HasPrefix(s1.str, s2.str)
}
// String returns the segment's string representation.
func (s segment) String() string {
return s.str