diff --git a/capability/policy/glob.go b/capability/policy/glob.go index 34d42a0..56cb505 100644 --- a/capability/policy/glob.go +++ b/capability/policy/glob.go @@ -1,7 +1,11 @@ package policy -// validateGlobPattern ensures the pattern conforms to the spec: only '*' and escaped '\*' are allowed. -func validateGlobPattern(pattern string) bool { +import "fmt" + +type glob string + +// parseGlob ensures that the pattern conforms to the spec: only '*' and escaped '\*' are allowed. +func parseGlob(pattern string) (glob, error) { for i := 0; i < len(pattern); i++ { if pattern[i] == '*' { continue @@ -15,15 +19,23 @@ func validateGlobPattern(pattern string) bool { continue } if pattern[i] == '\\' { - return false // invalid escape sequence + return "", fmt.Errorf("invalid escape sequence") } } - return true + return glob(pattern), nil } -// globMatch matches a string against a pattern with * wildcards, handling escaped '\*' literals. -func globMatch(pattern, str string) bool { +func mustParseGlob(pattern string) glob { + g, err := parseGlob(pattern) + if err != nil { + panic(err) + } + return g +} + +// Match matches a string against the glob pattern with * wildcards, handling escaped '\*' literals. +func (pattern glob) Match(str string) bool { // i is the index for the pattern // j is the index for the string var i, j int diff --git a/capability/policy/glob_test.go b/capability/policy/glob_test.go index 31451a0..a89bce3 100644 --- a/capability/policy/glob_test.go +++ b/capability/policy/glob_test.go @@ -3,7 +3,7 @@ package policy import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSimpleGlobMatch(t *testing.T) { @@ -56,7 +56,18 @@ func TestSimpleGlobMatch(t *testing.T) { for _, tt := range tests { t.Run(tt.pattern+"_"+tt.str, func(t *testing.T) { - assert.Equal(t, tt.matches, globMatch(tt.pattern, tt.str)) + g, err := parseGlob(tt.pattern) + require.NoError(t, err) + require.Equal(t, tt.matches, g.Match(tt.str)) }) } } + +func BenchmarkGlob(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + g := mustParseGlob("Alice\\*, Bob*, Carol.") + g.Match("Alice*, Bob*, Carol!") + } +} diff --git a/capability/policy/match.go b/capability/policy/match.go index ee65967..5313af4 100644 --- a/capability/policy/match.go +++ b/capability/policy/match.go @@ -112,7 +112,7 @@ func matchStatement(statement Statement, node ipld.Node) bool { if err != nil { return false } - return s.pattern.match(v) + return s.pattern.Match(v) } case KindAll: if s, ok := statement.(quantifier); ok { diff --git a/capability/policy/policy.go b/capability/policy/policy.go index 157be8f..e6f385e 100644 --- a/capability/policy/policy.go +++ b/capability/policy/policy.go @@ -3,8 +3,6 @@ package policy // https://github.com/ucan-wg/delegation/blob/4094d5878b58f5d35055a3b93fccda0b8329ebae/README.md#policy import ( - "errors" - "github.com/ipld/go-ipld-prime" "github.com/ucan-wg/go-ucan/capability/policy/selector" @@ -89,23 +87,9 @@ func Or(stmts ...Statement) Statement { return connective{kind: KindOr, statements: stmts} } -type wildcardPattern string - -func parseWildcardPattern(pattern string) (wildcardPattern, error) { - if !validateGlobPattern(pattern) { - return "", errors.New("invalid wildcard pattern") - } - - return wildcardPattern(pattern), nil -} - -func (wp wildcardPattern) match(str string) bool { - return globMatch(string(wp), str) -} - type wildcard struct { selector selector.Selector - pattern wildcardPattern + pattern glob } func (n wildcard) Kind() string { @@ -113,12 +97,12 @@ func (n wildcard) Kind() string { } func Like(selector selector.Selector, pattern string) (Statement, error) { - parsedPattern, err := parseWildcardPattern(pattern) + g, err := parseGlob(pattern) if err != nil { return nil, err } - return wildcard{selector: selector, pattern: parsedPattern}, nil + return wildcard{selector: selector, pattern: g}, nil } type quantifier struct {