glob: a bit of reshaping, and a benchmark

This commit is contained in:
Michael Muré
2024-09-18 11:24:37 +02:00
parent a19d3505fe
commit ac73cae3ec
4 changed files with 35 additions and 28 deletions

View File

@@ -1,7 +1,11 @@
package policy
// validateGlobPattern ensures the pattern conforms to the spec: only '*' and escaped '\*' are allowed.
func validateGlobPattern(pattern string) bool {
import "fmt"
type glob string
// parseGlob ensures that the pattern conforms to the spec: only '*' and escaped '\*' are allowed.
func parseGlob(pattern string) (glob, error) {
for i := 0; i < len(pattern); i++ {
if pattern[i] == '*' {
continue
@@ -15,15 +19,23 @@ func validateGlobPattern(pattern string) bool {
continue
}
if pattern[i] == '\\' {
return false // invalid escape sequence
return "", fmt.Errorf("invalid escape sequence")
}
}
return true
return glob(pattern), nil
}
// globMatch matches a string against a pattern with * wildcards, handling escaped '\*' literals.
func globMatch(pattern, str string) bool {
func mustParseGlob(pattern string) glob {
g, err := parseGlob(pattern)
if err != nil {
panic(err)
}
return g
}
// Match matches a string against the glob pattern with * wildcards, handling escaped '\*' literals.
func (pattern glob) Match(str string) bool {
// i is the index for the pattern
// j is the index for the string
var i, j int

View File

@@ -3,7 +3,7 @@ package policy
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSimpleGlobMatch(t *testing.T) {
@@ -56,7 +56,18 @@ func TestSimpleGlobMatch(t *testing.T) {
for _, tt := range tests {
t.Run(tt.pattern+"_"+tt.str, func(t *testing.T) {
assert.Equal(t, tt.matches, globMatch(tt.pattern, tt.str))
g, err := parseGlob(tt.pattern)
require.NoError(t, err)
require.Equal(t, tt.matches, g.Match(tt.str))
})
}
}
func BenchmarkGlob(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
g := mustParseGlob("Alice\\*, Bob*, Carol.")
g.Match("Alice*, Bob*, Carol!")
}
}

View File

@@ -112,7 +112,7 @@ func matchStatement(statement Statement, node ipld.Node) bool {
if err != nil {
return false
}
return s.pattern.match(v)
return s.pattern.Match(v)
}
case KindAll:
if s, ok := statement.(quantifier); ok {

View File

@@ -3,8 +3,6 @@ package policy
// https://github.com/ucan-wg/delegation/blob/4094d5878b58f5d35055a3b93fccda0b8329ebae/README.md#policy
import (
"errors"
"github.com/ipld/go-ipld-prime"
"github.com/ucan-wg/go-ucan/capability/policy/selector"
@@ -89,23 +87,9 @@ func Or(stmts ...Statement) Statement {
return connective{kind: KindOr, statements: stmts}
}
type wildcardPattern string
func parseWildcardPattern(pattern string) (wildcardPattern, error) {
if !validateGlobPattern(pattern) {
return "", errors.New("invalid wildcard pattern")
}
return wildcardPattern(pattern), nil
}
func (wp wildcardPattern) match(str string) bool {
return globMatch(string(wp), str)
}
type wildcard struct {
selector selector.Selector
pattern wildcardPattern
pattern glob
}
func (n wildcard) Kind() string {
@@ -113,12 +97,12 @@ func (n wildcard) Kind() string {
}
func Like(selector selector.Selector, pattern string) (Statement, error) {
parsedPattern, err := parseWildcardPattern(pattern)
g, err := parseGlob(pattern)
if err != nil {
return nil, err
}
return wildcard{selector: selector, pattern: parsedPattern}, nil
return wildcard{selector: selector, pattern: g}, nil
}
type quantifier struct {