diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index 23e4e59..48f813f 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -1,3 +1,4 @@ +// Package literal holds a collection of functions to create IPLD types to use in policies, selector and args. package literal import ( diff --git a/pkg/policy/match.go b/pkg/policy/match.go index 178b862..ed18587 100644 --- a/pkg/policy/match.go +++ b/pkg/policy/match.go @@ -9,7 +9,7 @@ import ( "github.com/ipld/go-ipld-prime/must" ) -// Match determines if the IPLD node matches the policy document. +// Match determines if the IPLD node satisfies the policy. func (p Policy) Match(node datamodel.Node) bool { for _, stmt := range p { ok := matchStatement(stmt, node) @@ -20,6 +20,20 @@ func (p Policy) Match(node datamodel.Node) bool { return true } +// Filter performs a recursive filtering of the Statement, and prunes what doesn't match the given path +func (p Policy) Filter(path ...string) Policy { + var filtered Policy + + for _, stmt := range p { + newChild, remain := filter(stmt, path) + if newChild != nil && len(remain) == 0 { + filtered = append(filtered, newChild) + } + } + + return filtered +} + func matchStatement(statement Statement, node ipld.Node) bool { switch statement.Kind() { case KindEqual: @@ -153,6 +167,70 @@ func matchStatement(statement Statement, node ipld.Node) bool { panic(fmt.Errorf("unimplemented statement kind: %s", statement.Kind())) } +// filter performs a recursive filtering of the Statement, and prunes what doesn't match the given path +func filter(stmt Statement, path []string) (Statement, []string) { + // For each kind, we do some of the following if it applies: + // - test the path against the selector, consuming segments + // - for terminal statements (equality, wildcard), require all the segments to have been consumed + // - recursively filter child (negation, quantifier) or children (connective) statements with the remaining path + switch stmt.(type) { + case equality: + match, remain := stmt.(equality).selector.MatchPath(path...) + if match && len(remain) == 0 { + return stmt, remain + } + return nil, nil + case negation: + newChild, remain := filter(stmt.(negation).statement, path) + if newChild != nil && len(remain) == 0 { + return negation{ + statement: newChild, + }, nil + } + return nil, nil + case connective: + var newChildren []Statement + for _, child := range stmt.(connective).statements { + newChild, remain := filter(child, path) + if newChild != nil && len(remain) == 0 { + newChildren = append(newChildren, newChild) + } + } + if len(newChildren) == 0 { + return nil, nil + } + return connective{ + kind: stmt.(connective).kind, + statements: newChildren, + }, nil + case wildcard: + match, remain := stmt.(wildcard).selector.MatchPath(path...) + if match && len(remain) == 0 { + return stmt, remain + } + return nil, nil + case quantifier: + match, remain := stmt.(quantifier).selector.MatchPath(path...) + if match && len(remain) == 0 { + return stmt, remain + } + if !match { + return nil, nil + } + newChild, remain := filter(stmt.(quantifier).statement, remain) + if newChild != nil && len(remain) == 0 { + return quantifier{ + kind: stmt.(quantifier).kind, + selector: stmt.(quantifier).selector, + statement: newChild, + }, nil + } + return nil, nil + default: + panic(fmt.Errorf("unimplemented statement kind: %s", stmt.Kind())) + } +} + func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool { if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int { a := must.Int(actual) diff --git a/pkg/policy/match_test.go b/pkg/policy/match_test.go index 9e3de4a..0aac982 100644 --- a/pkg/policy/match_test.go +++ b/pkg/policy/match_test.go @@ -2,6 +2,7 @@ package policy import ( "fmt" + "strings" "testing" "github.com/ipfs/go-cid" @@ -512,3 +513,103 @@ func FuzzMatch(f *testing.F) { policy.Match(dataNode) }) } + +func TestPolicyFilter(t *testing.T) { + pol := MustConstruct( + Any(".http", And( + Equal(".method", literal.String("GET")), + Equal(".path", literal.String("/foo")), + )), + Equal(".http", literal.String("foobar")), + All(".jsonrpc.foo", Or( + Not(Equal(".bar", literal.String("foo"))), + Equal(".", literal.String("foo")), + Like(".boo", "abcd"), + Like(".boo", "*bcd"), + )), + ) + + for _, tc := range []struct { + path string + expected Policy + }{ + { + path: "http", + expected: MustConstruct( + Any(".http", And( + Equal(".method", literal.String("GET")), + Equal(".path", literal.String("/foo")), + )), + Equal(".http", literal.String("foobar")), + ), + }, + { + path: "http,method", + expected: MustConstruct( + Any(".http", And( + Equal(".method", literal.String("GET")), + )), + ), + }, + { + path: "http,path", + expected: MustConstruct( + Any(".http", And( + Equal(".path", literal.String("/foo")), + )), + ), + }, + { + path: "http,foo", + expected: Policy{}, + }, + { + path: "jsonrpc", + expected: MustConstruct( + All(".jsonrpc.foo", Or( + Not(Equal(".bar", literal.String("foo"))), + Equal(".", literal.String("foo")), + Like(".boo", "abcd"), + Like(".boo", "*bcd"), + )), + ), + }, + { + path: "jsonrpc,baz", + expected: Policy{}, + }, + { + path: "jsonrpc,foo", + expected: MustConstruct( + All(".jsonrpc.foo", Or( + Not(Equal(".bar", literal.String("foo"))), + Equal(".", literal.String("foo")), + Like(".boo", "abcd"), + Like(".boo", "*bcd"), + )), + ), + }, + { + path: "jsonrpc,foo,bar", + expected: MustConstruct( + All(".jsonrpc.foo", Or( + Not(Equal(".bar", literal.String("foo"))), + )), + ), + }, + { + path: "jsonrpc,foo,boo", + expected: MustConstruct( + All(".jsonrpc.foo", Or( + Like(".boo", "abcd"), + Like(".boo", "*bcd"), + )), + ), + }, + } { + t.Run(tc.path, func(t *testing.T) { + res := pol.Filter(strings.Split(tc.path, ",")...) + require.Equal(t, tc.expected.String(), res.String()) + }) + } +} diff --git a/pkg/policy/selector/parsing.go b/pkg/policy/selector/parsing.go index 887be6c..16843e2 100644 --- a/pkg/policy/selector/parsing.go +++ b/pkg/policy/selector/parsing.go @@ -2,11 +2,17 @@ package selector import ( "fmt" + "regexp" "strconv" "strings" ) -var identity = Selector{segment{".", true, false, false, nil, "", 0}} +var ( + identity = Selector{segment{str: ".", identity: true}} + indexRegex = regexp.MustCompile(`^-?\d+$`) + sliceRegex = regexp.MustCompile(`^((\-?\d+:\-?\d*)|(\-?\d*:\-?\d+))$`) + fieldRegex = regexp.MustCompile(`^\.[a-zA-Z_]*?$`) +) func Parse(str string) (Selector, error) { if len(str) == 0 { @@ -32,9 +38,9 @@ func Parse(str string) (Selector, error) { if len(sel) > 0 && sel[len(sel)-1].Identity() { return nil, newParseError("selector contains unsupported recursive descent segment: '..'", str, col, tok) } - sel = append(sel, segment{".", true, false, false, nil, "", 0}) + sel = append(sel, segment{str: ".", identity: true}) case "[]": - sel = append(sel, segment{tok, false, opt, true, nil, "", 0}) + sel = append(sel, segment{str: tok, optional: opt, iterator: true}) default: if strings.HasPrefix(seg, "[") && strings.HasSuffix(seg, "]") { lookup := seg[1 : len(seg)-1] diff --git a/pkg/policy/selector/selector.go b/pkg/policy/selector/selector.go index f5ed2ad..b0c49ac 100644 --- a/pkg/policy/selector/selector.go +++ b/pkg/policy/selector/selector.go @@ -2,7 +2,6 @@ package selector import ( "fmt" - "regexp" "strings" "github.com/ipld/go-ipld-prime" @@ -21,6 +20,12 @@ func (s Selector) Select(subject ipld.Node) (ipld.Node, []ipld.Node, error) { return resolve(s, subject, nil) } +// MatchPath tells if the selector operates on the given (string only) path segments. +// It returns the segments that didn't get consumed by the matching. +func (s Selector) MatchPath(pathSegment ...string) (bool, []string) { + return matchPath(s, pathSegment) +} + func (s Selector) String() string { var res strings.Builder for _, seg := range s { @@ -29,12 +34,6 @@ func (s Selector) String() string { return res.String() } -var ( - indexRegex = regexp.MustCompile(`^-?\d+$`) - sliceRegex = regexp.MustCompile(`^((\-?\d+:\-?\d*)|(\-?\d*:\-?\d+))$`) - fieldRegex = regexp.MustCompile(`^\.[a-zA-Z_]*?$`) -) - type segment struct { str string identity bool @@ -316,7 +315,7 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No } } - default: + default: // Index() at = append(at, fmt.Sprintf("%d", seg.Index())) if cur == nil { if seg.Optional() { @@ -378,6 +377,39 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No return cur, nil, nil } +func matchPath(sel Selector, path []string) (bool, []string) { + for _, seg := range sel { + if len(path) == 0 { + return true, path + } + switch { + case seg.Identity(): + continue + + case seg.Iterator(): + // we have reached a [] iterator, it should have matched earlier + return false, nil + + case seg.Field() != "": + // if exact match on the segment, we continue + if path[0] == seg.Field() { + path = path[1:] + continue + } + return false, nil + + case seg.Slice() != nil: + // we have reached a [:] slicing, it should have matched earlier + return false, nil + + default: // Index() + // we have reached a [] indexing, it should have matched earlier + return false, nil + } + } + return true, path +} + // resolveSliceIndices resolves the start and end indices for slicing a list or byte array. // // It takes the slice indices from the selector segment and the length of the list or byte array, diff --git a/pkg/policy/selector/selector_test.go b/pkg/policy/selector/selector_test.go index fdbfd69..93e7fa0 100644 --- a/pkg/policy/selector/selector_test.go +++ b/pkg/policy/selector/selector_test.go @@ -431,6 +431,32 @@ func TestSelect(t *testing.T) { }) } +func TestMatch(t *testing.T) { + for _, tc := range []struct { + sel string + path []string + want bool + remaining []string + }{ + {sel: ".foo.bar", path: []string{"foo", "bar"}, want: true, remaining: []string{}}, + {sel: ".foo.bar", path: []string{"foo"}, want: true, remaining: []string{}}, + {sel: ".foo.bar", path: []string{"foo", "bar", "baz"}, want: true, remaining: []string{"baz"}}, + {sel: ".foo.bar", path: []string{"foo", "faa"}, want: false}, + {sel: ".foo.[]", path: []string{"foo", "faa"}, want: false}, + {sel: ".foo.[]", path: []string{"foo"}, want: true, remaining: []string{}}, + {sel: ".foo.bar?", path: []string{"foo"}, want: true, remaining: []string{}}, + {sel: ".foo.bar?", path: []string{"foo", "bar"}, want: true, remaining: []string{}}, + {sel: ".foo.bar?", path: []string{"foo", "baz"}, want: false}, + } { + t.Run(tc.sel, func(t *testing.T) { + sel := MustParse(tc.sel) + res, remain := sel.MatchPath(tc.path...) + require.Equal(t, tc.want, res) + require.EqualValues(t, tc.remaining, remain) + }) + } +} + func FuzzParse(f *testing.F) { selectorCorpus := []string{ `.`, `.[]`, `.[]?`, `.[][]?`, `.x`, `.["x"]`, `.[0]`, `.[-1]`, `.[0]`,