From ddaa67ed7dba98b0b3ec014408da198afd6e2815 Mon Sep 17 00:00:00 2001 From: Fabio Bozzo Date: Tue, 24 Sep 2024 11:33:06 +0200 Subject: [PATCH] feat(policy): filter statements subset by selector --- capability/policy/match.go | 23 +++++++ capability/policy/match_test.go | 86 ++++++++++++++++++++++++++ capability/policy/policy.go | 26 ++++++++ capability/policy/selector/selector.go | 47 ++++++++++++++ 4 files changed, 182 insertions(+) diff --git a/capability/policy/match.go b/capability/policy/match.go index 5313af4..079d63d 100644 --- a/capability/policy/match.go +++ b/capability/policy/match.go @@ -22,6 +22,29 @@ func Match(policy Policy, node ipld.Node) bool { return true } +// Filter extracts a subset of the policy related to the specified selector. +func (p Policy) Filter(sel selector.Selector) Policy { + var filtered Policy + for _, stmt := range p { + if stmt.Selector().Matches(sel) { + filtered = append(filtered, stmt) + } + } + + return filtered +} + +// Match determines if the IPLD node matches the policy document. +func (p Policy) Match(node datamodel.Node) bool { + for _, stmt := range p { + ok := matchStatement(stmt, node) + if !ok { + return false + } + } + return true +} + func matchStatement(statement Statement, node ipld.Node) bool { switch statement.Kind() { case KindEqual: diff --git a/capability/policy/match_test.go b/capability/policy/match_test.go index 4f7dc9b..d047b16 100644 --- a/capability/policy/match_test.go +++ b/capability/policy/match_test.go @@ -1,7 +1,9 @@ package policy import ( + "encoding/json" "fmt" + "reflect" "testing" "github.com/ipfs/go-cid" @@ -9,6 +11,7 @@ import ( "github.com/ipld/go-ipld-prime/codec/dagjson" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ucan-wg/go-ucan/capability/policy/literal" @@ -539,3 +542,86 @@ func FuzzMatch(f *testing.F) { Match(policy, dataNode) }) } + +func TestPolicyFilter(t *testing.T) { + sel1 := selector.Selector{selector.NewFieldSegment("http")} + sel2 := selector.Selector{selector.NewFieldSegment("jsonrpc")} + + stmt1 := Equal(sel1, basicnode.NewString("value1")) + stmt2 := Equal(sel2, basicnode.NewString("value2")) + + p := Policy{stmt1, stmt2} + + filtered := p.Filter(sel1) + assert.Len(t, filtered, 1) + assert.Equal(t, stmt1, filtered[0]) + + filtered = p.Filter(sel2) + assert.Len(t, filtered, 1) + assert.Equal(t, stmt2, filtered[0]) + + sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")} + filtered = p.Filter(sel3) + assert.Len(t, filtered, 0) +} + +func FuzzPolicyFilter(f *testing.F) { + f.Add([]byte(`{"selector": [{"field": "http"}], "value": "value1"}`)) + f.Add([]byte(`{"selector": [{"field": "jsonrpc"}], "value": "value2"}`)) + + f.Fuzz(func(t *testing.T, data []byte) { + var input struct { + Selector []struct { + Field string `json:"field"` // because selector.segment is not public + } `json:"selector"` + Value string `json:"value"` + } + if err := json.Unmarshal(data, &input); err != nil { + t.Skip() + } + + var sel selector.Selector + for _, seg := range input.Selector { + sel = append(sel, selector.NewFieldSegment(seg.Field)) + } + stmt := Equal(sel, basicnode.NewString(input.Value)) + + // create a policy and filter it based on the fuzzy input selector + p := Policy{stmt} + filtered := p.Filter(sel) + + // verify that the filtered policy contains the statement + if len(filtered) != 1 || !reflect.DeepEqual(filtered[0], stmt) { + t.Errorf("filtered policy does not contain the expected statement") + } + }) +} + +func BenchmarkPolicyFilter(b *testing.B) { + sel1 := selector.Selector{selector.NewFieldSegment("http")} + sel2 := selector.Selector{selector.NewFieldSegment("jsonrpc")} + + stmt1 := Equal(sel1, basicnode.NewString("value1")) + stmt2 := Equal(sel2, basicnode.NewString("value2")) + + p := Policy{stmt1, stmt2} + + b.Run("Filter by sel1", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.Filter(sel1) + } + }) + + b.Run("Filter by sel2", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.Filter(sel2) + } + }) + + sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")} + b.Run("Filter by sel3", func(b *testing.B) { + for i := 0; i < b.N; i++ { + p.Filter(sel3) + } + }) +} diff --git a/capability/policy/policy.go b/capability/policy/policy.go index e6f385e..7c6dc7d 100644 --- a/capability/policy/policy.go +++ b/capability/policy/policy.go @@ -26,6 +26,7 @@ type Policy []Statement type Statement interface { Kind() string + Selector() selector.Selector } type equality struct { @@ -38,6 +39,10 @@ func (e equality) Kind() string { return e.kind } +func (e equality) Selector() selector.Selector { + return e.selector +} + func Equal(selector selector.Selector, value ipld.Node) Statement { return equality{kind: KindEqual, selector: selector, value: value} } @@ -66,6 +71,10 @@ func (n negation) Kind() string { return KindNot } +func (n negation) Selector() selector.Selector { + return n.statement.Selector() +} + func Not(stmt Statement) Statement { return negation{statement: stmt} } @@ -79,6 +88,15 @@ func (c connective) Kind() string { return c.kind } +func (c connective) Selector() selector.Selector { + // assuming the first statement's selector is representative + if len(c.statements) > 0 { + return c.statements[0].Selector() + } + + return selector.Selector{} +} + func And(stmts ...Statement) Statement { return connective{kind: KindAnd, statements: stmts} } @@ -96,6 +114,10 @@ func (n wildcard) Kind() string { return KindLike } +func (n wildcard) Selector() selector.Selector { + return n.selector +} + func Like(selector selector.Selector, pattern string) (Statement, error) { g, err := parseGlob(pattern) if err != nil { @@ -115,6 +137,10 @@ func (n quantifier) Kind() string { return n.kind } +func (n quantifier) Selector() selector.Selector { + return n.selector +} + func All(selector selector.Selector, statement Statement) Statement { return quantifier{kind: KindAll, selector: selector, statement: statement} } diff --git a/capability/policy/selector/selector.go b/capability/policy/selector/selector.go index 1d57718..ed2a0f8 100644 --- a/capability/policy/selector/selector.go +++ b/capability/policy/selector/selector.go @@ -23,6 +23,21 @@ func (s Selector) String() string { return res.String() } +// Matches checks if the selector matches another selector. +func (s Selector) Matches(other Selector) bool { + if len(s) != len(other) { + return false + } + + for i, seg := range s { + if seg.str != other[i].str { + return false + } + } + + return true +} + var Identity = segment{".", true, false, false, nil, "", 0} var ( @@ -41,6 +56,38 @@ type segment struct { index int } +// NewFieldSegment creates a new segment for a field. +func NewFieldSegment(field string) segment { + return segment{ + str: fmt.Sprintf(".%s", field), + field: field, + } +} + +// NewIndexSegment creates a new segment for an index. +func NewIndexSegment(index int) segment { + return segment{ + str: fmt.Sprintf("[%d]", index), + index: index, + } +} + +// NewSliceSegment creates a new segment for a slice. +func NewSliceSegment(slice []int) segment { + return segment{ + str: fmt.Sprintf("[%d:%d]", slice[0], slice[1]), + slice: slice, + } +} + +// NewIteratorSegment creates a new segment for an iterator. +func NewIteratorSegment() segment { + return segment{ + str: "*", + iterator: true, + } +} + // String returns the segment's string representation. func (s segment) String() string { return s.str