diff --git a/match.go b/match.go index 95203ec..950ef23 100644 --- a/match.go +++ b/match.go @@ -5,35 +5,37 @@ import ( "fmt" "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/must" "github.com/storacha-network/go-ucanto/core/policy/selector" ) // Match determines if the IPLD node matches the policy document. -func Match(policy Policy, node ipld.Node) (bool, error) { +func Match(policy Policy, node ipld.Node) bool { for _, stmt := range policy { - ok, err := matchStatement(stmt, node) - if err != nil || !ok { - return ok, err + ok := matchStatement(stmt, node) + if !ok { + return ok } } - return true, nil + return true } -func matchStatement(statement Statement, node ipld.Node) (bool, error) { +func matchStatement(statement Statement, node ipld.Node) bool { switch statement.Kind() { case Kind_Equal: if s, ok := statement.(EqualityStatement); ok { one, _, err := selector.Select(s.Selector(), node) if err != nil || one == nil { - return false, nil + return false } - return isDeepEqual(s.Value(), one) + return datamodel.DeepEqual(s.Value(), one) } case Kind_GreaterThan: if s, ok := statement.(InequalityStatement); ok { one, _, err := selector.Select(s.Selector(), node) if err != nil || one == nil { - return false, nil + return false } return isOrdered(s.Value(), one, gt) } @@ -41,7 +43,7 @@ func matchStatement(statement Statement, node ipld.Node) (bool, error) { if s, ok := statement.(InequalityStatement); ok { one, _, err := selector.Select(s.Selector(), node) if err != nil || one == nil { - return false, nil + return false } return isOrdered(s.Value(), one, gte) } @@ -49,7 +51,7 @@ func matchStatement(statement Statement, node ipld.Node) (bool, error) { if s, ok := statement.(InequalityStatement); ok { one, _, err := selector.Select(s.Selector(), node) if err != nil || one == nil { - return false, nil + return false } return isOrdered(s.Value(), one, lt) } @@ -57,147 +59,64 @@ func matchStatement(statement Statement, node ipld.Node) (bool, error) { if s, ok := statement.(InequalityStatement); ok { one, _, err := selector.Select(s.Selector(), node) if err != nil || one == nil { - return false, nil + return false } return isOrdered(s.Value(), one, lte) } case Kind_Negation: if s, ok := statement.(NegationStatement); ok { - r, err := matchStatement(s.Value(), node) - if err != nil { - return false, err - } - return !r, err + return !matchStatement(s.Value(), node) } case Kind_Conjunction: if s, ok := statement.(ConjunctionStatement); ok { for _, cs := range s.Value() { - r, err := matchStatement(cs, node) - if err != nil { - return false, err - } + r := matchStatement(cs, node) if !r { - return false, nil + return false } } - return true, nil + return true } case Kind_Disjunction: if s, ok := statement.(DisjunctionStatement); ok { if len(s.Value()) == 0 { - return true, nil + return true } for _, cs := range s.Value() { - r, err := matchStatement(cs, node) - if err != nil { - return false, err - } + r := matchStatement(cs, node) if r { - return true, nil + return true } } - return false, nil + return false } case Kind_Wildcard: case Kind_Universal: case Kind_Existential: } - return false, fmt.Errorf("unimplemented statement kind: %s", statement.Kind()) + panic(fmt.Errorf("unimplemented statement kind: %s", statement.Kind())) } -func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) (bool, error) { +func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool { if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int { - a, err := actual.AsInt() - if err != nil { - return false, fmt.Errorf("extracting node int: %w", err) - } - b, err := expected.AsInt() - if err != nil { - return false, fmt.Errorf("extracting selector int: %w", err) - } - return satisfies(cmp.Compare(a, b)), nil + a := must.Int(actual) + b := must.Int(expected) + return satisfies(cmp.Compare(a, b)) } if expected.Kind() == ipld.Kind_Float && actual.Kind() == ipld.Kind_Float { a, err := actual.AsFloat() if err != nil { - return false, fmt.Errorf("extracting node float: %w", err) + panic(fmt.Errorf("extracting node float: %w", err)) } b, err := expected.AsFloat() if err != nil { - return false, fmt.Errorf("extracting selector float: %w", err) + panic(fmt.Errorf("extracting selector float: %w", err)) } - return satisfies(cmp.Compare(a, b)), nil + return satisfies(cmp.Compare(a, b)) } - return false, fmt.Errorf("unsupported IPLD kinds in ordered comparison: %s %s", expected.Kind(), actual.Kind()) -} - -func isDeepEqual(expected ipld.Node, actual ipld.Node) (bool, error) { - if expected.Kind() != actual.Kind() { - return false, nil - } - // TODO: should be easy enough to do the basic types, map, struct and list - // might be harder. - switch expected.Kind() { - case ipld.Kind_String: - a, err := actual.AsString() - if err != nil { - return false, fmt.Errorf("extracting node string: %w", err) - } - b, err := expected.AsString() - if err != nil { - return false, fmt.Errorf("extracting selector string: %w", err) - } - return a == b, nil - case ipld.Kind_Int: - if actual.Kind() != ipld.Kind_Int { - return false, nil - } - a, err := actual.AsInt() - if err != nil { - return false, fmt.Errorf("extracting node int: %w", err) - } - b, err := expected.AsInt() - if err != nil { - return false, fmt.Errorf("extracting selector int: %w", err) - } - return a == b, nil - case ipld.Kind_Float: - if actual.Kind() != ipld.Kind_Float { - return false, nil - } - a, err := actual.AsFloat() - if err != nil { - return false, fmt.Errorf("extracting node float: %w", err) - } - b, err := expected.AsFloat() - if err != nil { - return false, fmt.Errorf("extracting selector float: %w", err) - } - return a == b, nil - case ipld.Kind_Bool: - a, err := actual.AsBool() - if err != nil { - return false, fmt.Errorf("extracting node boolean: %w", err) - } - b, err := expected.AsBool() - if err != nil { - return false, fmt.Errorf("extracting selector node boolean: %w", err) - } - return a == b, nil - case ipld.Kind_Link: - a, err := actual.AsLink() - if err != nil { - return false, fmt.Errorf("extracting node link: %w", err) - } - b, err := expected.AsLink() - if err != nil { - return false, fmt.Errorf("extracting selector node link: %w", err) - } - return a.Binary() == b.Binary(), nil - } - return false, fmt.Errorf("unsupported IPLD kind in equality comparison: %s", expected.Kind()) + return false } func gt(order int) bool { return order == 1 } diff --git a/match_test.go b/match_test.go index 69e27dd..183f1bf 100644 --- a/match_test.go +++ b/match_test.go @@ -19,18 +19,15 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Equal(selector.MustParse("."), literal.String("test"))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.String("test2"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.Int(138))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -41,18 +38,15 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Equal(selector.MustParse("."), literal.Int(138))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.Int(1138))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.String("138"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -63,18 +57,15 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Equal(selector.MustParse("."), literal.Float(1.138))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.Float(11.38))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.String("138"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -88,18 +79,15 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Equal(selector.MustParse("."), literal.Link(l0))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.Link(l1))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{Equal(selector.MustParse("."), literal.String("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -113,23 +101,19 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Equal(selector.MustParse(".foo"), literal.String("bar"))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse(".[\"foo\"]"), literal.String("bar"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse(".foo"), literal.String("baz"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{Equal(selector.MustParse(".foobar"), literal.String("bar"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -142,13 +126,11 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Equal(selector.MustParse(".[0]"), literal.String("foo"))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Equal(selector.MustParse(".[1]"), literal.String("foo"))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -159,8 +141,7 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{GreaterThan(selector.MustParse("."), literal.Int(1))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) }) @@ -171,13 +152,11 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(1))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(138))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.True(t, ok) }) @@ -188,8 +167,7 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{GreaterThan(selector.MustParse("."), literal.Float(1))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) }) @@ -200,13 +178,37 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1.38))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) + require.True(t, ok) + }) + + t.Run("inequality lt int", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{LessThan(selector.MustParse("."), literal.Int(1138))} + ok := Match(pol, nd) + require.True(t, ok) + }) + + t.Run("inequality lte int", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(1138))} + ok := Match(pol, nd) + require.True(t, ok) + + pol = Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(138))} + ok = Match(pol, nd) require.True(t, ok) }) @@ -217,13 +219,11 @@ func TestMatch(t *testing.T) { nd := nb.Build() pol := Policy{Not(Equal(selector.MustParse("."), literal.Bool(true)))} - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{Not(Equal(selector.MustParse("."), literal.Bool(false)))} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) }) @@ -239,8 +239,7 @@ func TestMatch(t *testing.T) { LessThan(selector.MustParse("."), literal.Int(1138)), ), } - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{ @@ -249,13 +248,11 @@ func TestMatch(t *testing.T) { Equal(selector.MustParse("."), literal.Int(1138)), ), } - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{And()} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.True(t, ok) }) @@ -271,8 +268,7 @@ func TestMatch(t *testing.T) { LessThan(selector.MustParse("."), literal.Int(1138)), ), } - ok, err := Match(pol, nd) - require.NoError(t, err) + ok := Match(pol, nd) require.True(t, ok) pol = Policy{ @@ -281,13 +277,11 @@ func TestMatch(t *testing.T) { Equal(selector.MustParse("."), literal.Int(1138)), ), } - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.False(t, ok) pol = Policy{Or()} - ok, err = Match(pol, nd) - require.NoError(t, err) + ok = Match(pol, nd) require.True(t, ok) }) } diff --git a/statement.go b/policy.go similarity index 100% rename from statement.go rename to policy.go diff --git a/selector/selector.go b/selector/selector.go index 373a5b8..5f9d9a5 100644 --- a/selector/selector.go +++ b/selector/selector.go @@ -8,6 +8,7 @@ import ( "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/schema" ) // Selector describes a UCAN policy selector, as specified here: @@ -238,6 +239,8 @@ func MustParse(sel string) Selector { return s } +// Select uses a selector to extract an IPLD node or set of nodes from the +// passed subject node. func Select(sel Selector, subject ipld.Node) (ipld.Node, []ipld.Node, error) { return resolve(sel, subject, nil) } @@ -256,12 +259,12 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No break } - i, v, err := it.Next() + k, v, err := it.Next() if err != nil { return nil, nil, err } - key := fmt.Sprintf("%d", i) + key := fmt.Sprintf("%d", k) o, m, err := resolve(sel[i+1:], v, append(at[:], key)) if err != nil { return nil, nil, err @@ -311,7 +314,7 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No if cur != nil && cur.Kind() == datamodel.Kind_Map { n, err := cur.LookupByString(seg.Field()) if err != nil { - if _, ok := err.(datamodel.ErrNotExists); ok { + if isMissing(err) { if seg.Optional() { cur = nil } else { @@ -342,7 +345,7 @@ func resolve(sel Selector, subject ipld.Node, at []string) (ipld.Node, []ipld.No if cur != nil && cur.Kind() == datamodel.Kind_List { n, err := cur.LookupByIndex(int64(seg.Index())) if err != nil { - if _, ok := err.(datamodel.ErrNotExists); ok { + if isMissing(err) { if seg.Optional() { cur = nil } else { @@ -371,6 +374,19 @@ func kindString(n datamodel.Node) string { return n.Kind().String() } +func isMissing(err error) bool { + if _, ok := err.(datamodel.ErrNotExists); ok { + return true + } + if _, ok := err.(schema.ErrNoSuchField); ok { + return true + } + if _, ok := err.(schema.ErrInvalidKey); ok { + return true + } + return false +} + type ResolutionError interface { error Name() string diff --git a/selector/selector_test.go b/selector/selector_test.go index 3173cd1..b19282f 100644 --- a/selector/selector_test.go +++ b/selector/selector_test.go @@ -4,6 +4,10 @@ import ( "fmt" "testing" + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/must" + "github.com/ipld/go-ipld-prime/node/bindnode" + "github.com/ipld/go-ipld-prime/printer" "github.com/stretchr/testify/require" ) @@ -171,9 +175,11 @@ func TestParse(t *testing.T) { }) t.Run("nesting", func(t *testing.T) { - sel, err := Parse(`.foo.["bar"].[138]?.baz[1:]`) + str := `.foo.["bar"].[138]?.baz[1:]` + sel, err := Parse(str) require.NoError(t, err) printSegments(sel) + require.Equal(t, str, sel.String()) require.Equal(t, 7, len(sel)) require.False(t, sel[0].Identity()) require.False(t, sel[0].Optional()) @@ -237,3 +243,189 @@ func printSegments(s Selector) { fmt.Printf("%d: %s\n", i, seg.String()) } } + +func TestSelect(t *testing.T) { + type name struct { + First string + Middle *string + Last string + } + type interest struct { + Name string + Outdoor bool + Experience int + } + type user struct { + Name name + Age int + Nationalities []string + Interests []interest + } + + ts, err := ipld.LoadSchemaBytes([]byte(` + type User struct { + name Name + age Int + nationalities [String] + interests [Interest] + } + type Name struct { + first String + middle optional String + last String + } + type Interest struct { + name String + outdoor Bool + experience Int + } + `)) + require.NoError(t, err) + typ := ts.TypeByName("User") + + am := "Joan" + alice := user{ + Name: name{First: "Alice", Middle: &am, Last: "Wonderland"}, + Age: 24, + Nationalities: []string{"British"}, + Interests: []interest{ + {Name: "Cycling", Outdoor: true, Experience: 4}, + {Name: "Chess", Outdoor: false, Experience: 2}, + }, + } + bob := user{ + Name: name{First: "Bob", Last: "Builder"}, + Age: 35, + Nationalities: []string{"Canadian", "South African"}, + Interests: []interest{ + {Name: "Snowboarding", Outdoor: true, Experience: 8}, + {Name: "Reading", Outdoor: false, Experience: 25}, + }, + } + + anode := bindnode.Wrap(&alice, typ) + bnode := bindnode.Wrap(&bob, typ) + + t.Run("identity", func(t *testing.T) { + sel, err := Parse(".") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.NoError(t, err) + require.NotEmpty(t, one) + require.Empty(t, many) + + fmt.Println(printer.Sprint(one)) + + age := must.Int(must.Node(one.LookupByString("age"))) + require.Equal(t, int64(alice.Age), age) + }) + + t.Run("nested property", func(t *testing.T) { + sel, err := Parse(".name.first") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.NoError(t, err) + require.NotEmpty(t, one) + require.Empty(t, many) + + fmt.Println(printer.Sprint(one)) + + name := must.String(one) + require.Equal(t, alice.Name.First, name) + + one, many, err = Select(sel, bnode) + require.NoError(t, err) + require.NotEmpty(t, one) + require.Empty(t, many) + + fmt.Println(printer.Sprint(one)) + + name = must.String(one) + require.Equal(t, bob.Name.First, name) + }) + + t.Run("optional nested property", func(t *testing.T) { + sel, err := Parse(".name.middle?") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.NoError(t, err) + require.NotEmpty(t, one) + require.Empty(t, many) + + fmt.Println(printer.Sprint(one)) + + name := must.String(one) + require.Equal(t, *alice.Name.Middle, name) + + one, many, err = Select(sel, bnode) + require.NoError(t, err) + require.Empty(t, one) + require.Empty(t, many) + }) + + t.Run("not exists", func(t *testing.T) { + sel, err := Parse(".name.foo") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.Error(t, err) + require.Empty(t, one) + require.Empty(t, many) + + fmt.Println(err) + + if _, ok := err.(ResolutionError); !ok { + t.Fatalf("error was not a resolution error") + } + }) + + t.Run("optional not exists", func(t *testing.T) { + sel, err := Parse(".name.foo?") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.NoError(t, err) + require.Empty(t, one) + require.Empty(t, many) + }) + + t.Run("iterator", func(t *testing.T) { + sel, err := Parse(".interests[]") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.NoError(t, err) + require.Empty(t, one) + require.NotEmpty(t, many) + + for _, n := range many { + fmt.Println(printer.Sprint(n)) + } + + iname := must.String(must.Node(many[0].LookupByString("name"))) + require.Equal(t, alice.Interests[0].Name, iname) + + iname = must.String(must.Node(many[1].LookupByString("name"))) + require.Equal(t, alice.Interests[1].Name, iname) + }) + + t.Run("map iterator", func(t *testing.T) { + sel, err := Parse(".interests[0][]") + require.NoError(t, err) + + one, many, err := Select(sel, anode) + require.NoError(t, err) + require.Empty(t, one) + require.NotEmpty(t, many) + + for _, n := range many { + fmt.Println(printer.Sprint(n)) + } + + require.Equal(t, alice.Interests[0].Name, must.String(many[0])) + require.Equal(t, alice.Interests[0].Experience, int(must.Int(many[2]))) + }) +}