Merge pull request #44 from ucan-wg/policy-filtering
policy: add a way to filter policies with a path
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
// Package literal holds a collection of functions to create IPLD types to use in policies, selector and args.
|
||||
package literal
|
||||
|
||||
import (
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/ipld/go-ipld-prime/must"
|
||||
)
|
||||
|
||||
// Match determines if the IPLD node matches the policy document.
|
||||
// Match determines if the IPLD node satisfies the policy.
|
||||
func (p Policy) Match(node datamodel.Node) bool {
|
||||
for _, stmt := range p {
|
||||
ok := matchStatement(stmt, node)
|
||||
@@ -20,6 +20,20 @@ func (p Policy) Match(node datamodel.Node) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Filter performs a recursive filtering of the Statement, and prunes what doesn't match the given path
|
||||
func (p Policy) Filter(path ...string) Policy {
|
||||
var filtered Policy
|
||||
|
||||
for _, stmt := range p {
|
||||
newChild, remain := filter(stmt, path)
|
||||
if newChild != nil && len(remain) == 0 {
|
||||
filtered = append(filtered, newChild)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
func matchStatement(statement Statement, node ipld.Node) bool {
|
||||
switch statement.Kind() {
|
||||
case KindEqual:
|
||||
@@ -153,6 +167,70 @@ func matchStatement(statement Statement, node ipld.Node) bool {
|
||||
panic(fmt.Errorf("unimplemented statement kind: %s", statement.Kind()))
|
||||
}
|
||||
|
||||
// filter performs a recursive filtering of the Statement, and prunes what doesn't match the given path
|
||||
func filter(stmt Statement, path []string) (Statement, []string) {
|
||||
// For each kind, we do some of the following if it applies:
|
||||
// - test the path against the selector, consuming segments
|
||||
// - for terminal statements (equality, wildcard), require all the segments to have been consumed
|
||||
// - recursively filter child (negation, quantifier) or children (connective) statements with the remaining path
|
||||
switch stmt.(type) {
|
||||
case equality:
|
||||
match, remain := stmt.(equality).selector.MatchPath(path...)
|
||||
if match && len(remain) == 0 {
|
||||
return stmt, remain
|
||||
}
|
||||
return nil, nil
|
||||
case negation:
|
||||
newChild, remain := filter(stmt.(negation).statement, path)
|
||||
if newChild != nil && len(remain) == 0 {
|
||||
return negation{
|
||||
statement: newChild,
|
||||
}, nil
|
||||
}
|
||||
return nil, nil
|
||||
case connective:
|
||||
var newChildren []Statement
|
||||
for _, child := range stmt.(connective).statements {
|
||||
newChild, remain := filter(child, path)
|
||||
if newChild != nil && len(remain) == 0 {
|
||||
newChildren = append(newChildren, newChild)
|
||||
}
|
||||
}
|
||||
if len(newChildren) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return connective{
|
||||
kind: stmt.(connective).kind,
|
||||
statements: newChildren,
|
||||
}, nil
|
||||
case wildcard:
|
||||
match, remain := stmt.(wildcard).selector.MatchPath(path...)
|
||||
if match && len(remain) == 0 {
|
||||
return stmt, remain
|
||||
}
|
||||
return nil, nil
|
||||
case quantifier:
|
||||
match, remain := stmt.(quantifier).selector.MatchPath(path...)
|
||||
if match && len(remain) == 0 {
|
||||
return stmt, remain
|
||||
}
|
||||
if !match {
|
||||
return nil, nil
|
||||
}
|
||||
newChild, remain := filter(stmt.(quantifier).statement, remain)
|
||||
if newChild != nil && len(remain) == 0 {
|
||||
return quantifier{
|
||||
kind: stmt.(quantifier).kind,
|
||||
selector: stmt.(quantifier).selector,
|
||||
statement: newChild,
|
||||
}, nil
|
||||
}
|
||||
return nil, nil
|
||||
default:
|
||||
panic(fmt.Errorf("unimplemented statement kind: %s", stmt.Kind()))
|
||||
}
|
||||
}
|
||||
|
||||
func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool {
|
||||
if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int {
|
||||
a := must.Int(actual)
|
||||
|
||||
@@ -2,6 +2,7 @@ package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ipfs/go-cid"
|
||||
@@ -512,3 +513,103 @@ func FuzzMatch(f *testing.F) {
|
||||
policy.Match(dataNode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicyFilter(t *testing.T) {
|
||||
pol := MustConstruct(
|
||||
Any(".http", And(
|
||||
Equal(".method", literal.String("GET")),
|
||||
Equal(".path", literal.String("/foo")),
|
||||
)),
|
||||
Equal(".http", literal.String("foobar")),
|
||||
All(".jsonrpc.foo", Or(
|
||||
Not(Equal(".bar", literal.String("foo"))),
|
||||
Equal(".", literal.String("foo")),
|
||||
Like(".boo", "abcd"),
|
||||
Like(".boo", "*bcd"),
|
||||
)),
|
||||
)
|
||||
|
||||
for _, tc := range []struct {
|
||||
path string
|
||||
expected Policy
|
||||
}{
|
||||
{
|
||||
path: "http",
|
||||
expected: MustConstruct(
|
||||
Any(".http", And(
|
||||
Equal(".method", literal.String("GET")),
|
||||
Equal(".path", literal.String("/foo")),
|
||||
)),
|
||||
Equal(".http", literal.String("foobar")),
|
||||
),
|
||||
},
|
||||
{
|
||||
path: "http,method",
|
||||
expected: MustConstruct(
|
||||
Any(".http", And(
|
||||
Equal(".method", literal.String("GET")),
|
||||
)),
|
||||
),
|
||||
},
|
||||
{
|
||||
path: "http,path",
|
||||
expected: MustConstruct(
|
||||
Any(".http", And(
|
||||
Equal(".path", literal.String("/foo")),
|
||||
)),
|
||||
),
|
||||
},
|
||||
{
|
||||
path: "http,foo",
|
||||
expected: Policy{},
|
||||
},
|
||||
{
|
||||
path: "jsonrpc",
|
||||
expected: MustConstruct(
|
||||
All(".jsonrpc.foo", Or(
|
||||
Not(Equal(".bar", literal.String("foo"))),
|
||||
Equal(".", literal.String("foo")),
|
||||
Like(".boo", "abcd"),
|
||||
Like(".boo", "*bcd"),
|
||||
)),
|
||||
),
|
||||
},
|
||||
{
|
||||
path: "jsonrpc,baz",
|
||||
expected: Policy{},
|
||||
},
|
||||
{
|
||||
path: "jsonrpc,foo",
|
||||
expected: MustConstruct(
|
||||
All(".jsonrpc.foo", Or(
|
||||
Not(Equal(".bar", literal.String("foo"))),
|
||||
Equal(".", literal.String("foo")),
|
||||
Like(".boo", "abcd"),
|
||||
Like(".boo", "*bcd"),
|
||||
)),
|
||||
),
|
||||
},
|
||||
{
|
||||
path: "jsonrpc,foo,bar",
|
||||
expected: MustConstruct(
|
||||
All(".jsonrpc.foo", Or(
|
||||
Not(Equal(".bar", literal.String("foo"))),
|
||||
)),
|
||||
),
|
||||
},
|
||||
{
|
||||
path: "jsonrpc,foo,boo",
|
||||
expected: MustConstruct(
|
||||
All(".jsonrpc.foo", Or(
|
||||
Like(".boo", "abcd"),
|
||||
Like(".boo", "*bcd"),
|
||||
)),
|
||||
),
|
||||
},
|
||||
} {
|
||||
t.Run(tc.path, func(t *testing.T) {
|
||||
res := pol.Filter(strings.Split(tc.path, ",")...)
|
||||
require.Equal(t, tc.expected.String(), res.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,17 @@ package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var identity = Selector{segment{".", true, false, false, nil, "", 0}}
|
||||
var (
|
||||
identity = Selector{segment{str: ".", identity: true}}
|
||||
indexRegex = regexp.MustCompile(`^-?\d+$`)
|
||||
sliceRegex = regexp.MustCompile(`^((\-?\d+:\-?\d*)|(\-?\d*:\-?\d+))$`)
|
||||
fieldRegex = regexp.MustCompile(`^\.[a-zA-Z_]*?$`)
|
||||
)
|
||||
|
||||
func Parse(str string) (Selector, error) {
|
||||
if len(str) == 0 {
|
||||
@@ -32,9 +38,9 @@ func Parse(str string) (Selector, error) {
|
||||
if len(sel) > 0 && sel[len(sel)-1].Identity() {
|
||||
return nil, newParseError("selector contains unsupported recursive descent segment: '..'", str, col, tok)
|
||||
}
|
||||
sel = append(sel, segment{".", true, false, false, nil, "", 0})
|
||||
sel = append(sel, segment{str: ".", identity: true})
|
||||
case "[]":
|
||||
sel = append(sel, segment{tok, false, opt, true, nil, "", 0})
|
||||
sel = append(sel, segment{str: tok, optional: opt, iterator: true})
|
||||
default:
|
||||
if strings.HasPrefix(seg, "[") && strings.HasSuffix(seg, "]") {
|
||||
lookup := seg[1 : len(seg)-1]
|
||||
|
||||
@@ -2,7 +2,6 @@ package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ipld/go-ipld-prime"
|
||||
@@ -21,6 +20,12 @@ func (s Selector) Select(subject ipld.Node) (ipld.Node, []ipld.Node, error) {
|
||||
return resolve(s, subject, nil)
|
||||
}
|
||||
|
||||
// MatchPath tells if the selector operates on the given (string only) path segments.
|
||||
// It returns the segments that didn't get consumed by the matching.
|
||||
func (s Selector) MatchPath(pathSegment ...string) (bool, []string) {
|
||||
return matchPath(s, pathSegment)
|
||||
}
|
||||
|
||||
func (s Selector) String() string {
|
||||
var res strings.Builder
|
||||
for _, seg := range s {
|
||||
@@ -29,12 +34,6 @@ func (s Selector) String() string {
|
||||
return res.String()
|
||||
}
|
||||
|
||||
var (
|
||||
indexRegex = regexp.MustCompile(`^-?\d+$`)
|
||||
sliceRegex = regexp.MustCompile(`^((\-?\d+:\-?\d*)|(\-?\d*:\-?\d+))$`)
|
||||
fieldRegex = regexp.MustCompile(`^\.[a-zA-Z_]*?$`)
|
||||
)
|
||||
|
||||
type segment struct {
|
||||
str string
|
||||
identity bool
|
||||
@@ -316,7 +315,7 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
default: // Index()
|
||||
at = append(at, fmt.Sprintf("%d", seg.Index()))
|
||||
if cur == nil {
|
||||
if seg.Optional() {
|
||||
@@ -378,6 +377,39 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No
|
||||
return cur, nil, nil
|
||||
}
|
||||
|
||||
func matchPath(sel Selector, path []string) (bool, []string) {
|
||||
for _, seg := range sel {
|
||||
if len(path) == 0 {
|
||||
return true, path
|
||||
}
|
||||
switch {
|
||||
case seg.Identity():
|
||||
continue
|
||||
|
||||
case seg.Iterator():
|
||||
// we have reached a [] iterator, it should have matched earlier
|
||||
return false, nil
|
||||
|
||||
case seg.Field() != "":
|
||||
// if exact match on the segment, we continue
|
||||
if path[0] == seg.Field() {
|
||||
path = path[1:]
|
||||
continue
|
||||
}
|
||||
return false, nil
|
||||
|
||||
case seg.Slice() != nil:
|
||||
// we have reached a [<int>:<int>] slicing, it should have matched earlier
|
||||
return false, nil
|
||||
|
||||
default: // Index()
|
||||
// we have reached a [<int>] indexing, it should have matched earlier
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, path
|
||||
}
|
||||
|
||||
// resolveSliceIndices resolves the start and end indices for slicing a list or byte array.
|
||||
//
|
||||
// It takes the slice indices from the selector segment and the length of the list or byte array,
|
||||
|
||||
@@ -431,6 +431,32 @@ func TestSelect(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatch(t *testing.T) {
|
||||
for _, tc := range []struct {
|
||||
sel string
|
||||
path []string
|
||||
want bool
|
||||
remaining []string
|
||||
}{
|
||||
{sel: ".foo.bar", path: []string{"foo", "bar"}, want: true, remaining: []string{}},
|
||||
{sel: ".foo.bar", path: []string{"foo"}, want: true, remaining: []string{}},
|
||||
{sel: ".foo.bar", path: []string{"foo", "bar", "baz"}, want: true, remaining: []string{"baz"}},
|
||||
{sel: ".foo.bar", path: []string{"foo", "faa"}, want: false},
|
||||
{sel: ".foo.[]", path: []string{"foo", "faa"}, want: false},
|
||||
{sel: ".foo.[]", path: []string{"foo"}, want: true, remaining: []string{}},
|
||||
{sel: ".foo.bar?", path: []string{"foo"}, want: true, remaining: []string{}},
|
||||
{sel: ".foo.bar?", path: []string{"foo", "bar"}, want: true, remaining: []string{}},
|
||||
{sel: ".foo.bar?", path: []string{"foo", "baz"}, want: false},
|
||||
} {
|
||||
t.Run(tc.sel, func(t *testing.T) {
|
||||
sel := MustParse(tc.sel)
|
||||
res, remain := sel.MatchPath(tc.path...)
|
||||
require.Equal(t, tc.want, res)
|
||||
require.EqualValues(t, tc.remaining, remain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzParse(f *testing.F) {
|
||||
selectorCorpus := []string{
|
||||
`.`, `.[]`, `.[]?`, `.[][]?`, `.x`, `.["x"]`, `.[0]`, `.[-1]`, `.[0]`,
|
||||
|
||||
Reference in New Issue
Block a user