feat(policy): filter statements subset by selector
This commit is contained in:
@@ -22,6 +22,29 @@ func Match(policy Policy, node ipld.Node) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Filter extracts a subset of the policy related to the specified selector.
|
||||
func (p Policy) Filter(sel selector.Selector) Policy {
|
||||
var filtered Policy
|
||||
for _, stmt := range p {
|
||||
if stmt.Selector().Matches(sel) {
|
||||
filtered = append(filtered, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// Match determines if the IPLD node matches the policy document.
|
||||
func (p Policy) Match(node datamodel.Node) bool {
|
||||
for _, stmt := range p {
|
||||
ok := matchStatement(stmt, node)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func matchStatement(statement Statement, node ipld.Node) bool {
|
||||
switch statement.Kind() {
|
||||
case KindEqual:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/ipfs/go-cid"
|
||||
@@ -9,6 +11,7 @@ import (
|
||||
"github.com/ipld/go-ipld-prime/codec/dagjson"
|
||||
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
|
||||
"github.com/ipld/go-ipld-prime/node/basicnode"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/capability/policy/literal"
|
||||
@@ -539,3 +542,86 @@ func FuzzMatch(f *testing.F) {
|
||||
Match(policy, dataNode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicyFilter(t *testing.T) {
|
||||
sel1 := selector.Selector{selector.NewFieldSegment("http")}
|
||||
sel2 := selector.Selector{selector.NewFieldSegment("jsonrpc")}
|
||||
|
||||
stmt1 := Equal(sel1, basicnode.NewString("value1"))
|
||||
stmt2 := Equal(sel2, basicnode.NewString("value2"))
|
||||
|
||||
p := Policy{stmt1, stmt2}
|
||||
|
||||
filtered := p.Filter(sel1)
|
||||
assert.Len(t, filtered, 1)
|
||||
assert.Equal(t, stmt1, filtered[0])
|
||||
|
||||
filtered = p.Filter(sel2)
|
||||
assert.Len(t, filtered, 1)
|
||||
assert.Equal(t, stmt2, filtered[0])
|
||||
|
||||
sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")}
|
||||
filtered = p.Filter(sel3)
|
||||
assert.Len(t, filtered, 0)
|
||||
}
|
||||
|
||||
func FuzzPolicyFilter(f *testing.F) {
|
||||
f.Add([]byte(`{"selector": [{"field": "http"}], "value": "value1"}`))
|
||||
f.Add([]byte(`{"selector": [{"field": "jsonrpc"}], "value": "value2"}`))
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
var input struct {
|
||||
Selector []struct {
|
||||
Field string `json:"field"` // because selector.segment is not public
|
||||
} `json:"selector"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &input); err != nil {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
var sel selector.Selector
|
||||
for _, seg := range input.Selector {
|
||||
sel = append(sel, selector.NewFieldSegment(seg.Field))
|
||||
}
|
||||
stmt := Equal(sel, basicnode.NewString(input.Value))
|
||||
|
||||
// create a policy and filter it based on the fuzzy input selector
|
||||
p := Policy{stmt}
|
||||
filtered := p.Filter(sel)
|
||||
|
||||
// verify that the filtered policy contains the statement
|
||||
if len(filtered) != 1 || !reflect.DeepEqual(filtered[0], stmt) {
|
||||
t.Errorf("filtered policy does not contain the expected statement")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkPolicyFilter(b *testing.B) {
|
||||
sel1 := selector.Selector{selector.NewFieldSegment("http")}
|
||||
sel2 := selector.Selector{selector.NewFieldSegment("jsonrpc")}
|
||||
|
||||
stmt1 := Equal(sel1, basicnode.NewString("value1"))
|
||||
stmt2 := Equal(sel2, basicnode.NewString("value2"))
|
||||
|
||||
p := Policy{stmt1, stmt2}
|
||||
|
||||
b.Run("Filter by sel1", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.Filter(sel1)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Filter by sel2", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.Filter(sel2)
|
||||
}
|
||||
})
|
||||
|
||||
sel3 := selector.Selector{selector.NewFieldSegment("nonexistent")}
|
||||
b.Run("Filter by sel3", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.Filter(sel3)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ type Policy []Statement
|
||||
|
||||
type Statement interface {
|
||||
Kind() string
|
||||
Selector() selector.Selector
|
||||
}
|
||||
|
||||
type equality struct {
|
||||
@@ -38,6 +39,10 @@ func (e equality) Kind() string {
|
||||
return e.kind
|
||||
}
|
||||
|
||||
func (e equality) Selector() selector.Selector {
|
||||
return e.selector
|
||||
}
|
||||
|
||||
func Equal(selector selector.Selector, value ipld.Node) Statement {
|
||||
return equality{kind: KindEqual, selector: selector, value: value}
|
||||
}
|
||||
@@ -66,6 +71,10 @@ func (n negation) Kind() string {
|
||||
return KindNot
|
||||
}
|
||||
|
||||
func (n negation) Selector() selector.Selector {
|
||||
return n.statement.Selector()
|
||||
}
|
||||
|
||||
func Not(stmt Statement) Statement {
|
||||
return negation{statement: stmt}
|
||||
}
|
||||
@@ -79,6 +88,15 @@ func (c connective) Kind() string {
|
||||
return c.kind
|
||||
}
|
||||
|
||||
func (c connective) Selector() selector.Selector {
|
||||
// assuming the first statement's selector is representative
|
||||
if len(c.statements) > 0 {
|
||||
return c.statements[0].Selector()
|
||||
}
|
||||
|
||||
return selector.Selector{}
|
||||
}
|
||||
|
||||
func And(stmts ...Statement) Statement {
|
||||
return connective{kind: KindAnd, statements: stmts}
|
||||
}
|
||||
@@ -96,6 +114,10 @@ func (n wildcard) Kind() string {
|
||||
return KindLike
|
||||
}
|
||||
|
||||
func (n wildcard) Selector() selector.Selector {
|
||||
return n.selector
|
||||
}
|
||||
|
||||
func Like(selector selector.Selector, pattern string) (Statement, error) {
|
||||
g, err := parseGlob(pattern)
|
||||
if err != nil {
|
||||
@@ -115,6 +137,10 @@ func (n quantifier) Kind() string {
|
||||
return n.kind
|
||||
}
|
||||
|
||||
func (n quantifier) Selector() selector.Selector {
|
||||
return n.selector
|
||||
}
|
||||
|
||||
func All(selector selector.Selector, statement Statement) Statement {
|
||||
return quantifier{kind: KindAll, selector: selector, statement: statement}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,21 @@ func (s Selector) String() string {
|
||||
return res.String()
|
||||
}
|
||||
|
||||
// Matches checks if the selector matches another selector.
|
||||
func (s Selector) Matches(other Selector) bool {
|
||||
if len(s) != len(other) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, seg := range s {
|
||||
if seg.str != other[i].str {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
var Identity = segment{".", true, false, false, nil, "", 0}
|
||||
|
||||
var (
|
||||
@@ -41,6 +56,38 @@ type segment struct {
|
||||
index int
|
||||
}
|
||||
|
||||
// NewFieldSegment creates a new segment for a field.
|
||||
func NewFieldSegment(field string) segment {
|
||||
return segment{
|
||||
str: fmt.Sprintf(".%s", field),
|
||||
field: field,
|
||||
}
|
||||
}
|
||||
|
||||
// NewIndexSegment creates a new segment for an index.
|
||||
func NewIndexSegment(index int) segment {
|
||||
return segment{
|
||||
str: fmt.Sprintf("[%d]", index),
|
||||
index: index,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSliceSegment creates a new segment for a slice.
|
||||
func NewSliceSegment(slice []int) segment {
|
||||
return segment{
|
||||
str: fmt.Sprintf("[%d:%d]", slice[0], slice[1]),
|
||||
slice: slice,
|
||||
}
|
||||
}
|
||||
|
||||
// NewIteratorSegment creates a new segment for an iterator.
|
||||
func NewIteratorSegment() segment {
|
||||
return segment{
|
||||
str: "*",
|
||||
iterator: true,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the segment's string representation.
|
||||
func (s segment) String() string {
|
||||
return s.str
|
||||
|
||||
Reference in New Issue
Block a user