feat: simplify

This commit is contained in:
Alan Shaw
2024-08-20 22:27:56 +02:00
parent e871b98cba
commit e7bbe02143
5 changed files with 302 additions and 181 deletions

143
match.go
View File

@@ -5,35 +5,37 @@ import (
"fmt"
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/must"
"github.com/storacha-network/go-ucanto/core/policy/selector"
)
// Match determines if the IPLD node matches the policy document.
func Match(policy Policy, node ipld.Node) (bool, error) {
func Match(policy Policy, node ipld.Node) bool {
for _, stmt := range policy {
ok, err := matchStatement(stmt, node)
if err != nil || !ok {
return ok, err
ok := matchStatement(stmt, node)
if !ok {
return ok
}
}
return true, nil
return true
}
func matchStatement(statement Statement, node ipld.Node) (bool, error) {
func matchStatement(statement Statement, node ipld.Node) bool {
switch statement.Kind() {
case Kind_Equal:
if s, ok := statement.(EqualityStatement); ok {
one, _, err := selector.Select(s.Selector(), node)
if err != nil || one == nil {
return false, nil
return false
}
return isDeepEqual(s.Value(), one)
return datamodel.DeepEqual(s.Value(), one)
}
case Kind_GreaterThan:
if s, ok := statement.(InequalityStatement); ok {
one, _, err := selector.Select(s.Selector(), node)
if err != nil || one == nil {
return false, nil
return false
}
return isOrdered(s.Value(), one, gt)
}
@@ -41,7 +43,7 @@ func matchStatement(statement Statement, node ipld.Node) (bool, error) {
if s, ok := statement.(InequalityStatement); ok {
one, _, err := selector.Select(s.Selector(), node)
if err != nil || one == nil {
return false, nil
return false
}
return isOrdered(s.Value(), one, gte)
}
@@ -49,7 +51,7 @@ func matchStatement(statement Statement, node ipld.Node) (bool, error) {
if s, ok := statement.(InequalityStatement); ok {
one, _, err := selector.Select(s.Selector(), node)
if err != nil || one == nil {
return false, nil
return false
}
return isOrdered(s.Value(), one, lt)
}
@@ -57,147 +59,64 @@ func matchStatement(statement Statement, node ipld.Node) (bool, error) {
if s, ok := statement.(InequalityStatement); ok {
one, _, err := selector.Select(s.Selector(), node)
if err != nil || one == nil {
return false, nil
return false
}
return isOrdered(s.Value(), one, lte)
}
case Kind_Negation:
if s, ok := statement.(NegationStatement); ok {
r, err := matchStatement(s.Value(), node)
if err != nil {
return false, err
}
return !r, err
return !matchStatement(s.Value(), node)
}
case Kind_Conjunction:
if s, ok := statement.(ConjunctionStatement); ok {
for _, cs := range s.Value() {
r, err := matchStatement(cs, node)
if err != nil {
return false, err
}
r := matchStatement(cs, node)
if !r {
return false, nil
return false
}
}
return true, nil
return true
}
case Kind_Disjunction:
if s, ok := statement.(DisjunctionStatement); ok {
if len(s.Value()) == 0 {
return true, nil
return true
}
for _, cs := range s.Value() {
r, err := matchStatement(cs, node)
if err != nil {
return false, err
}
r := matchStatement(cs, node)
if r {
return true, nil
return true
}
}
return false, nil
return false
}
case Kind_Wildcard:
case Kind_Universal:
case Kind_Existential:
}
return false, fmt.Errorf("unimplemented statement kind: %s", statement.Kind())
panic(fmt.Errorf("unimplemented statement kind: %s", statement.Kind()))
}
func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) (bool, error) {
func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool {
if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int {
a, err := actual.AsInt()
if err != nil {
return false, fmt.Errorf("extracting node int: %w", err)
}
b, err := expected.AsInt()
if err != nil {
return false, fmt.Errorf("extracting selector int: %w", err)
}
return satisfies(cmp.Compare(a, b)), nil
a := must.Int(actual)
b := must.Int(expected)
return satisfies(cmp.Compare(a, b))
}
if expected.Kind() == ipld.Kind_Float && actual.Kind() == ipld.Kind_Float {
a, err := actual.AsFloat()
if err != nil {
return false, fmt.Errorf("extracting node float: %w", err)
panic(fmt.Errorf("extracting node float: %w", err))
}
b, err := expected.AsFloat()
if err != nil {
return false, fmt.Errorf("extracting selector float: %w", err)
panic(fmt.Errorf("extracting selector float: %w", err))
}
return satisfies(cmp.Compare(a, b)), nil
return satisfies(cmp.Compare(a, b))
}
return false, fmt.Errorf("unsupported IPLD kinds in ordered comparison: %s %s", expected.Kind(), actual.Kind())
}
func isDeepEqual(expected ipld.Node, actual ipld.Node) (bool, error) {
if expected.Kind() != actual.Kind() {
return false, nil
}
// TODO: should be easy enough to do the basic types, map, struct and list
// might be harder.
switch expected.Kind() {
case ipld.Kind_String:
a, err := actual.AsString()
if err != nil {
return false, fmt.Errorf("extracting node string: %w", err)
}
b, err := expected.AsString()
if err != nil {
return false, fmt.Errorf("extracting selector string: %w", err)
}
return a == b, nil
case ipld.Kind_Int:
if actual.Kind() != ipld.Kind_Int {
return false, nil
}
a, err := actual.AsInt()
if err != nil {
return false, fmt.Errorf("extracting node int: %w", err)
}
b, err := expected.AsInt()
if err != nil {
return false, fmt.Errorf("extracting selector int: %w", err)
}
return a == b, nil
case ipld.Kind_Float:
if actual.Kind() != ipld.Kind_Float {
return false, nil
}
a, err := actual.AsFloat()
if err != nil {
return false, fmt.Errorf("extracting node float: %w", err)
}
b, err := expected.AsFloat()
if err != nil {
return false, fmt.Errorf("extracting selector float: %w", err)
}
return a == b, nil
case ipld.Kind_Bool:
a, err := actual.AsBool()
if err != nil {
return false, fmt.Errorf("extracting node boolean: %w", err)
}
b, err := expected.AsBool()
if err != nil {
return false, fmt.Errorf("extracting selector node boolean: %w", err)
}
return a == b, nil
case ipld.Kind_Link:
a, err := actual.AsLink()
if err != nil {
return false, fmt.Errorf("extracting node link: %w", err)
}
b, err := expected.AsLink()
if err != nil {
return false, fmt.Errorf("extracting selector node link: %w", err)
}
return a.Binary() == b.Binary(), nil
}
return false, fmt.Errorf("unsupported IPLD kind in equality comparison: %s", expected.Kind())
return false
}
func gt(order int) bool { return order == 1 }

View File

@@ -19,18 +19,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.String("test"))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("test2"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Int(138))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -41,18 +38,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.Int(138))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Int(1138))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("138"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -63,18 +57,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.Float(1.138))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Float(11.38))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("138"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -88,18 +79,15 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse("."), literal.Link(l0))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.Link(l1))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse("."), literal.String("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -113,23 +101,19 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse(".foo"), literal.String("bar"))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse(".[\"foo\"]"), literal.String("bar"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse(".foo"), literal.String("baz"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{Equal(selector.MustParse(".foobar"), literal.String("bar"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -142,13 +126,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Equal(selector.MustParse(".[0]"), literal.String("foo"))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Equal(selector.MustParse(".[1]"), literal.String("foo"))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -159,8 +141,7 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThan(selector.MustParse("."), literal.Int(1))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
})
@@ -171,13 +152,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(1))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(138))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.True(t, ok)
})
@@ -188,8 +167,7 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThan(selector.MustParse("."), literal.Float(1))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
})
@@ -200,13 +178,37 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1.38))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.True(t, ok)
})
t.Run("inequality lt int", func(t *testing.T) {
np := basicnode.Prototype.Int
nb := np.NewBuilder()
nb.AssignInt(138)
nd := nb.Build()
pol := Policy{LessThan(selector.MustParse("."), literal.Int(1138))}
ok := Match(pol, nd)
require.True(t, ok)
})
t.Run("inequality lte int", func(t *testing.T) {
np := basicnode.Prototype.Int
nb := np.NewBuilder()
nb.AssignInt(138)
nd := nb.Build()
pol := Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(1138))}
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(138))}
ok = Match(pol, nd)
require.True(t, ok)
})
@@ -217,13 +219,11 @@ func TestMatch(t *testing.T) {
nd := nb.Build()
pol := Policy{Not(Equal(selector.MustParse("."), literal.Bool(true)))}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{Not(Equal(selector.MustParse("."), literal.Bool(false)))}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
})
@@ -239,8 +239,7 @@ func TestMatch(t *testing.T) {
LessThan(selector.MustParse("."), literal.Int(1138)),
),
}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{
@@ -249,13 +248,11 @@ func TestMatch(t *testing.T) {
Equal(selector.MustParse("."), literal.Int(1138)),
),
}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{And()}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.True(t, ok)
})
@@ -271,8 +268,7 @@ func TestMatch(t *testing.T) {
LessThan(selector.MustParse("."), literal.Int(1138)),
),
}
ok, err := Match(pol, nd)
require.NoError(t, err)
ok := Match(pol, nd)
require.True(t, ok)
pol = Policy{
@@ -281,13 +277,11 @@ func TestMatch(t *testing.T) {
Equal(selector.MustParse("."), literal.Int(1138)),
),
}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.False(t, ok)
pol = Policy{Or()}
ok, err = Match(pol, nd)
require.NoError(t, err)
ok = Match(pol, nd)
require.True(t, ok)
})
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/schema"
)
// Selector describes a UCAN policy selector, as specified here:
@@ -238,6 +239,8 @@ func MustParse(sel string) Selector {
return s
}
// Select uses a selector to extract an IPLD node or set of nodes from the
// passed subject node.
func Select(sel Selector, subject ipld.Node) (ipld.Node, []ipld.Node, error) {
return resolve(sel, subject, nil)
}
@@ -256,12 +259,12 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No
break
}
i, v, err := it.Next()
k, v, err := it.Next()
if err != nil {
return nil, nil, err
}
key := fmt.Sprintf("%d", i)
key := fmt.Sprintf("%d", k)
o, m, err := resolve(sel[i+1:], v, append(at[:], key))
if err != nil {
return nil, nil, err
@@ -311,7 +314,7 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No
if cur != nil && cur.Kind() == datamodel.Kind_Map {
n, err := cur.LookupByString(seg.Field())
if err != nil {
if _, ok := err.(datamodel.ErrNotExists); ok {
if isMissing(err) {
if seg.Optional() {
cur = nil
} else {
@@ -342,7 +345,7 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No
if cur != nil && cur.Kind() == datamodel.Kind_List {
n, err := cur.LookupByIndex(int64(seg.Index()))
if err != nil {
if _, ok := err.(datamodel.ErrNotExists); ok {
if isMissing(err) {
if seg.Optional() {
cur = nil
} else {
@@ -371,6 +374,19 @@ func kindString(n datamodel.Node) string {
return n.Kind().String()
}
func isMissing(err error) bool {
if _, ok := err.(datamodel.ErrNotExists); ok {
return true
}
if _, ok := err.(schema.ErrNoSuchField); ok {
return true
}
if _, ok := err.(schema.ErrInvalidKey); ok {
return true
}
return false
}
type ResolutionError interface {
error
Name() string

View File

@@ -4,6 +4,10 @@ import (
"fmt"
"testing"
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/must"
"github.com/ipld/go-ipld-prime/node/bindnode"
"github.com/ipld/go-ipld-prime/printer"
"github.com/stretchr/testify/require"
)
@@ -171,9 +175,11 @@ func TestParse(t *testing.T) {
})
t.Run("nesting", func(t *testing.T) {
sel, err := Parse(`.foo.["bar"].[138]?.baz[1:]`)
str := `.foo.["bar"].[138]?.baz[1:]`
sel, err := Parse(str)
require.NoError(t, err)
printSegments(sel)
require.Equal(t, str, sel.String())
require.Equal(t, 7, len(sel))
require.False(t, sel[0].Identity())
require.False(t, sel[0].Optional())
@@ -237,3 +243,189 @@ func printSegments(s Selector) {
fmt.Printf("%d: %s\n", i, seg.String())
}
}
func TestSelect(t *testing.T) {
type name struct {
First string
Middle *string
Last string
}
type interest struct {
Name string
Outdoor bool
Experience int
}
type user struct {
Name name
Age int
Nationalities []string
Interests []interest
}
ts, err := ipld.LoadSchemaBytes([]byte(`
type User struct {
name Name
age Int
nationalities [String]
interests [Interest]
}
type Name struct {
first String
middle optional String
last String
}
type Interest struct {
name String
outdoor Bool
experience Int
}
`))
require.NoError(t, err)
typ := ts.TypeByName("User")
am := "Joan"
alice := user{
Name: name{First: "Alice", Middle: &am, Last: "Wonderland"},
Age: 24,
Nationalities: []string{"British"},
Interests: []interest{
{Name: "Cycling", Outdoor: true, Experience: 4},
{Name: "Chess", Outdoor: false, Experience: 2},
},
}
bob := user{
Name: name{First: "Bob", Last: "Builder"},
Age: 35,
Nationalities: []string{"Canadian", "South African"},
Interests: []interest{
{Name: "Snowboarding", Outdoor: true, Experience: 8},
{Name: "Reading", Outdoor: false, Experience: 25},
},
}
anode := bindnode.Wrap(&alice, typ)
bnode := bindnode.Wrap(&bob, typ)
t.Run("identity", func(t *testing.T) {
sel, err := Parse(".")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.NoError(t, err)
require.NotEmpty(t, one)
require.Empty(t, many)
fmt.Println(printer.Sprint(one))
age := must.Int(must.Node(one.LookupByString("age")))
require.Equal(t, int64(alice.Age), age)
})
t.Run("nested property", func(t *testing.T) {
sel, err := Parse(".name.first")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.NoError(t, err)
require.NotEmpty(t, one)
require.Empty(t, many)
fmt.Println(printer.Sprint(one))
name := must.String(one)
require.Equal(t, alice.Name.First, name)
one, many, err = Select(sel, bnode)
require.NoError(t, err)
require.NotEmpty(t, one)
require.Empty(t, many)
fmt.Println(printer.Sprint(one))
name = must.String(one)
require.Equal(t, bob.Name.First, name)
})
t.Run("optional nested property", func(t *testing.T) {
sel, err := Parse(".name.middle?")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.NoError(t, err)
require.NotEmpty(t, one)
require.Empty(t, many)
fmt.Println(printer.Sprint(one))
name := must.String(one)
require.Equal(t, *alice.Name.Middle, name)
one, many, err = Select(sel, bnode)
require.NoError(t, err)
require.Empty(t, one)
require.Empty(t, many)
})
t.Run("not exists", func(t *testing.T) {
sel, err := Parse(".name.foo")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.Error(t, err)
require.Empty(t, one)
require.Empty(t, many)
fmt.Println(err)
if _, ok := err.(ResolutionError); !ok {
t.Fatalf("error was not a resolution error")
}
})
t.Run("optional not exists", func(t *testing.T) {
sel, err := Parse(".name.foo?")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.NoError(t, err)
require.Empty(t, one)
require.Empty(t, many)
})
t.Run("iterator", func(t *testing.T) {
sel, err := Parse(".interests[]")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.NoError(t, err)
require.Empty(t, one)
require.NotEmpty(t, many)
for _, n := range many {
fmt.Println(printer.Sprint(n))
}
iname := must.String(must.Node(many[0].LookupByString("name")))
require.Equal(t, alice.Interests[0].Name, iname)
iname = must.String(must.Node(many[1].LookupByString("name")))
require.Equal(t, alice.Interests[1].Name, iname)
})
t.Run("map iterator", func(t *testing.T) {
sel, err := Parse(".interests[0][]")
require.NoError(t, err)
one, many, err := Select(sel, anode)
require.NoError(t, err)
require.Empty(t, one)
require.NotEmpty(t, many)
for _, n := range many {
fmt.Println(printer.Sprint(n))
}
require.Equal(t, alice.Interests[0].Name, must.String(many[0]))
require.Equal(t, alice.Interests[0].Experience, int(must.Int(many[2])))
})
}