From a9f59a50811668fe201a5e9cd7e0fbdf65684ed3 Mon Sep 17 00:00:00 2001 From: Alan Shaw Date: Mon, 19 Aug 2024 23:16:36 +0200 Subject: [PATCH] feat: add policy implementation --- literal/literal.go | 124 +++++++++++++++++ match.go | 242 ++++++++++++++++++++++++++++++++ match_test.go | 283 ++++++++++++++++++++++++++++++++++++++ selector/selector.go | 110 +++++++++++++++ selector/selector_test.go | 134 ++++++++++++++++++ statement.go | 204 +++++++++++++++++++++++++++ 6 files changed, 1097 insertions(+) create mode 100644 literal/literal.go create mode 100644 match.go create mode 100644 match_test.go create mode 100644 selector/selector.go create mode 100644 selector/selector_test.go create mode 100644 statement.go diff --git a/literal/literal.go b/literal/literal.go new file mode 100644 index 0000000..61d949e --- /dev/null +++ b/literal/literal.go @@ -0,0 +1,124 @@ +package literal + +import ( + "fmt" + + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" +) + +var ErrType = fmt.Errorf("literal is not this type") + +const ( + Kind_IPLD = "ipld" + Kind_Int = "int" + Kind_Float = "float" + Kind_String = "string" +) + +type Literal interface { + Kind() string // ipld | integer | float | string + AsNode() (ipld.Node, error) + AsInt() (int64, error) + AsFloat() (float64, error) + AsString() (string, error) +} + +type literal struct{} + +func (l literal) AsFloat() (float64, error) { + return 0, ErrType +} + +func (l literal) AsInt() (int64, error) { + return 0, ErrType +} + +func (l literal) AsNode() (datamodel.Node, error) { + return nil, ErrType +} + +func (l literal) AsString() (string, error) { + return "", ErrType +} + +type node struct { + literal + value ipld.Node +} + +func (l node) AsNode() (datamodel.Node, error) { + return l.value, nil +} + +func (l node) Kind() string { + return Kind_IPLD +} + +func Node(n ipld.Node) Literal { + return node{value: n} +} + +func Link(cid ipld.Link) Literal { + nb := basicnode.Prototype.Link.NewBuilder() + nb.AssignLink(cid) + return node{value: nb.Build()} +} + +func Bool(val bool) Literal { + nb := basicnode.Prototype.Bool.NewBuilder() + nb.AssignBool(val) + return node{value: nb.Build()} +} + +type nint struct { + literal + value int64 +} + +func (l nint) AsInt() (int64, error) { + return l.value, nil +} + +func (l nint) Kind() string { + return Kind_Int +} + +func Int(num int64) Literal { + return nint{value: num} +} + +type nfloat struct { + literal + value float64 +} + +func (l nfloat) AsFloat() (float64, error) { + return l.value, nil +} + +func (l nfloat) Kind() string { + return Kind_Float +} + +func Float(num float64) Literal { + return nfloat{value: num} +} + +type str struct { + literal + value string +} + +func (l str) AsString() (string, error) { + return l.value, nil +} + +func (l str) Kind() string { + return Kind_String +} + +func String(s string) Literal { + return str{value: s} +} diff --git a/match.go b/match.go new file mode 100644 index 0000000..63266a2 --- /dev/null +++ b/match.go @@ -0,0 +1,242 @@ +package policy + +import ( + "cmp" + "fmt" + + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/storacha-network/go-ucanto/core/policy/literal" + "github.com/storacha-network/go-ucanto/core/policy/selector" +) + +// Match determines if the IPLD node matches the policy document. +func Match(policy Policy, node ipld.Node) (bool, error) { + for _, stmt := range policy { + ok, err := matchStatement(stmt, node) + if err != nil || !ok { + return ok, err + } + } + return true, nil +} + +func matchStatement(statement Statement, node ipld.Node) (bool, error) { + switch statement.Kind() { + case Kind_Equal: + if s, ok := statement.(EqualityStatement); ok { + n, err := selectNode(s.Selector(), node) + if err != nil { + if _, ok := err.(datamodel.ErrNotExists); ok { + return false, nil + } + return false, fmt.Errorf("selecting node: %w", err) + } + return isDeepEqual(s.Value(), n) + } + case Kind_GreaterThan: + if s, ok := statement.(InequalityStatement); ok { + n, err := selectNode(s.Selector(), node) + if err != nil { + if _, ok := err.(datamodel.ErrNotExists); ok { + return false, nil + } + return false, fmt.Errorf("selecting node: %w", err) + } + return isOrdered(s, n, gt) + } + case Kind_GreaterThanOrEqual: + if s, ok := statement.(InequalityStatement); ok { + n, err := selectNode(s.Selector(), node) + if err != nil { + if _, ok := err.(datamodel.ErrNotExists); ok { + return false, nil + } + return false, fmt.Errorf("selecting node: %w", err) + } + return isOrdered(s, n, gte) + } + case Kind_LessThan: + if s, ok := statement.(InequalityStatement); ok { + n, err := selectNode(s.Selector(), node) + if err != nil { + if _, ok := err.(datamodel.ErrNotExists); ok { + return false, nil + } + return false, fmt.Errorf("selecting node: %w", err) + } + return isOrdered(s, n, lt) + } + case Kind_LessThanOrEqual: + if s, ok := statement.(InequalityStatement); ok { + n, err := selectNode(s.Selector(), node) + if err != nil { + if _, ok := err.(datamodel.ErrNotExists); ok { + return false, nil + } + return false, fmt.Errorf("selecting node: %w", err) + } + return isOrdered(s, n, lte) + } + case Kind_Negation: + if s, ok := statement.(NegationStatement); ok { + r, err := matchStatement(s.Value(), node) + if err != nil { + return false, err + } + return !r, err + } + case Kind_Conjunction: + if s, ok := statement.(ConjunctionStatement); ok { + for _, cs := range s.Value() { + r, err := matchStatement(cs, node) + if err != nil { + return false, err + } + if !r { + return false, nil + } + } + return true, nil + } + case Kind_Disjunction: + if s, ok := statement.(DisjunctionStatement); ok { + for _, cs := range s.Value() { + r, err := matchStatement(cs, node) + if err != nil { + return false, err + } + if r { + return true, nil + } + } + return false, nil + } + case Kind_Wildcard: + case Kind_Universal: + case Kind_Existential: + } + return false, fmt.Errorf("statement kind not implemented: %s", statement.Kind()) +} + +func selectNode(sel selector.Selector, node ipld.Node) (child ipld.Node, err error) { + if sel.Identity() { + child = node + } else if sel.Field() != "" { + child, err = node.LookupByString(sel.Field()) + } else { + child, err = node.LookupByIndex(int64(sel.Index())) + } + return +} + +func isOrdered(stmt InequalityStatement, node ipld.Node, satisfies func(order int) bool) (bool, error) { + if stmt.Value().Kind() == literal.Kind_Int && node.Kind() == ipld.Kind_Int { + a, err := node.AsInt() + if err != nil { + return false, fmt.Errorf("extracting node int: %w", err) + } + b, err := stmt.Value().AsInt() + if err != nil { + return false, fmt.Errorf("extracting selector int: %w", err) + } + return satisfies(cmp.Compare(a, b)), nil + } + + if stmt.Value().Kind() == literal.Kind_Float && node.Kind() == ipld.Kind_Float { + a, err := node.AsFloat() + if err != nil { + return false, fmt.Errorf("extracting node float: %w", err) + } + b, err := stmt.Value().AsFloat() + if err != nil { + return false, fmt.Errorf("extracting selector float: %w", err) + } + return satisfies(cmp.Compare(a, b)), nil + } + + return false, fmt.Errorf("selector type %s is not compatible with node type %s: kind mismatch: need int or float", stmt.Value().Kind(), node.Kind()) +} + +func isDeepEqual(value literal.Literal, node ipld.Node) (bool, error) { + switch value.Kind() { + case literal.Kind_String: + if node.Kind() != ipld.Kind_String { + return false, nil + } + a, err := node.AsString() + if err != nil { + return false, fmt.Errorf("extracting node string: %w", err) + } + b, err := value.AsString() + if err != nil { + return false, fmt.Errorf("extracting selector string: %w", err) + } + return a == b, nil + case literal.Kind_Int: + if node.Kind() != ipld.Kind_Int { + return false, nil + } + a, err := node.AsInt() + if err != nil { + return false, fmt.Errorf("extracting node int: %w", err) + } + b, err := value.AsInt() + if err != nil { + return false, fmt.Errorf("extracting selector int: %w", err) + } + return a == b, nil + case literal.Kind_Float: + if node.Kind() != ipld.Kind_Float { + return false, nil + } + a, err := node.AsFloat() + if err != nil { + return false, fmt.Errorf("extracting node float: %w", err) + } + b, err := value.AsFloat() + if err != nil { + return false, fmt.Errorf("extracting selector float: %w", err) + } + return a == b, nil + case literal.Kind_IPLD: + v, err := value.AsNode() + if err != nil { + return false, fmt.Errorf("extracting selector node: %w", err) + } + if v.Kind() != node.Kind() { + return false, nil + } + // TODO: should be easy enough to do the basic types, map, struct and list + // might be harder. + switch v.Kind() { + case ipld.Kind_Bool: + a, err := node.AsBool() + if err != nil { + return false, fmt.Errorf("extracting node boolean: %w", err) + } + b, err := v.AsBool() + if err != nil { + return false, fmt.Errorf("extracting selector node boolean: %w", err) + } + return a == b, nil + case ipld.Kind_Link: + a, err := node.AsLink() + if err != nil { + return false, fmt.Errorf("extracting node link: %w", err) + } + b, err := v.AsLink() + if err != nil { + return false, fmt.Errorf("extracting selector node link: %w", err) + } + return a.Binary() == b.Binary(), nil + } + return false, fmt.Errorf("unsupported IPLD kind: %s", v.Kind()) + } + return false, fmt.Errorf("unknown literal kind: %s", value.Kind()) +} + +func gt(order int) bool { return order == 1 } +func gte(order int) bool { return order == 0 || order == 1 } +func lt(order int) bool { return order == -1 } +func lte(order int) bool { return order == 0 || order == -1 } diff --git a/match_test.go b/match_test.go new file mode 100644 index 0000000..5d9f1f7 --- /dev/null +++ b/match_test.go @@ -0,0 +1,283 @@ +package policy + +import ( + "testing" + + "github.com/ipfs/go-cid" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/storacha-network/go-ucanto/core/policy/literal" + "github.com/storacha-network/go-ucanto/core/policy/selector" + "github.com/stretchr/testify/require" +) + +func TestMatch(t *testing.T) { + t.Run("equality string", func(t *testing.T) { + np := basicnode.Prototype.String + nb := np.NewBuilder() + nb.AssignString("test") + nd := nb.Build() + + pol := Policy{Equal(selector.MustParse("."), literal.String("test"))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.String("test2"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.Int(138))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("equality int", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{Equal(selector.MustParse("."), literal.Int(138))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.Int(1138))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.String("138"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("equality float", func(t *testing.T) { + np := basicnode.Prototype.Float + nb := np.NewBuilder() + nb.AssignFloat(1.138) + nd := nb.Build() + + pol := Policy{Equal(selector.MustParse("."), literal.Float(1.138))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.Float(11.38))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.String("138"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("equality IPLD Link", func(t *testing.T) { + l0 := cidlink.Link{Cid: cid.MustParse("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq")} + l1 := cidlink.Link{Cid: cid.MustParse("bafkreifau35r7vi37tvbvfy3hdwvgb4tlflqf7zcdzeujqcjk3rsphiwte")} + + np := basicnode.Prototype.Link + nb := np.NewBuilder() + nb.AssignLink(l0) + nd := nb.Build() + + pol := Policy{Equal(selector.MustParse("."), literal.Link(l0))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.Link(l1))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + + pol = Policy{Equal(selector.MustParse("."), literal.String("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("equality string in map", func(t *testing.T) { + np := basicnode.Prototype.Map + nb := np.NewBuilder() + ma, _ := nb.BeginMap(1) + ma.AssembleKey().AssignString("foo") + ma.AssembleValue().AssignString("bar") + ma.Finish() + nd := nb.Build() + + pol := Policy{Equal(selector.MustParse(".foo"), literal.String("bar"))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse(".[\"foo\"]"), literal.String("bar"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse(".foo"), literal.String("baz"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + + pol = Policy{Equal(selector.MustParse(".foobar"), literal.String("bar"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("equality string in list", func(t *testing.T) { + np := basicnode.Prototype.List + nb := np.NewBuilder() + la, _ := nb.BeginList(1) + la.AssembleValue().AssignString("foo") + la.Finish() + nd := nb.Build() + + pol := Policy{Equal(selector.MustParse(".[0]"), literal.String("foo"))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Equal(selector.MustParse(".[1]"), literal.String("foo"))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("inequality gt int", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{GreaterThan(selector.MustParse("."), literal.Int(1))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("inequality gte int", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(1))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(138))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("inequality gt float", func(t *testing.T) { + np := basicnode.Prototype.Float + nb := np.NewBuilder() + nb.AssignFloat(1.38) + nd := nb.Build() + + pol := Policy{GreaterThan(selector.MustParse("."), literal.Float(1))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("inequality gte float", func(t *testing.T) { + np := basicnode.Prototype.Float + nb := np.NewBuilder() + nb.AssignFloat(1.38) + nd := nb.Build() + + pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1.38))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + }) + + t.Run("negation", func(t *testing.T) { + np := basicnode.Prototype.Bool + nb := np.NewBuilder() + nb.AssignBool(false) + nd := nb.Build() + + pol := Policy{Not(Equal(selector.MustParse("."), literal.Bool(true)))} + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{Not(Equal(selector.MustParse("."), literal.Bool(false)))} + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("conjunction", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{ + And( + GreaterThan(selector.MustParse("."), literal.Int(1)), + LessThan(selector.MustParse("."), literal.Int(1138)), + ), + } + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{ + And( + GreaterThan(selector.MustParse("."), literal.Int(1)), + Equal(selector.MustParse("."), literal.Int(1138)), + ), + } + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("disjunction", func(t *testing.T) { + np := basicnode.Prototype.Int + nb := np.NewBuilder() + nb.AssignInt(138) + nd := nb.Build() + + pol := Policy{ + Or( + GreaterThan(selector.MustParse("."), literal.Int(138)), + LessThan(selector.MustParse("."), literal.Int(1138)), + ), + } + ok, err := Match(pol, nd) + require.NoError(t, err) + require.True(t, ok) + + pol = Policy{ + Or( + GreaterThan(selector.MustParse("."), literal.Int(138)), + Equal(selector.MustParse("."), literal.Int(1138)), + ), + } + ok, err = Match(pol, nd) + require.NoError(t, err) + require.False(t, ok) + }) +} diff --git a/selector/selector.go b/selector/selector.go new file mode 100644 index 0000000..c00f840 --- /dev/null +++ b/selector/selector.go @@ -0,0 +1,110 @@ +package selector + +import ( + "fmt" + "strconv" + "strings" +) + +// Selector describes a UCAN policy selector, as specified here: +// https://github.com/ucan-wg/delegation/blob/4094d5878b58f5d35055a3b93fccda0b8329ebae/README.md#selectors +type Selector interface { + // Identity flags that this selector is the identity selector. + Identity() bool + // Optional flags that this selector is optional. + Optional() bool + // Field is the name of a field in a struct/map. + Field() string + // Index is an index of a slice. + Index() int + // String returns the selector's string representation. + String() string +} + +type selector struct { + str string + identity bool + optional bool + field string + index int +} + +func (s selector) Field() string { + return s.field +} + +func (s selector) Identity() bool { + return s.identity +} + +func (s selector) Index() int { + return s.index +} + +func (s selector) Optional() bool { + return s.optional +} + +func (s selector) String() string { + return s.str +} + +// TODO: probably regex or better parser +func Parse(sel string) (Selector, error) { + s := sel + if s == "." { + return selector{sel, true, false, "", 0}, nil + } + + optional := strings.HasSuffix(s, "?") + if optional { + s = s[0 : len(s)-1] + } + + dotted := strings.HasPrefix(s, ".") + if dotted { + s = s[1:] + } + + // collection values + if s == "[]" { + return nil, fmt.Errorf("unsupported selector: %s", sel) + } + + if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") { + s = s[1 : len(s)-1] + + // explicit field selector + if strings.HasPrefix(s, "\"") && strings.HasSuffix(s, "\"") { + return selector{sel, false, optional, s[1 : len(s)-1], 0}, nil + } + + // collection range + if strings.Contains(s, ":") { + return nil, fmt.Errorf("unsupported selector: %s", sel) + } + + // index selector + idx, err := strconv.Atoi(s) + if err != nil { + return nil, fmt.Errorf("parsing index selector value: %s", err) + } + + return selector{sel, false, optional, "", idx}, nil + } + + if !dotted { + return nil, fmt.Errorf("invalid selector: %s", sel) + } + + // dotted field selector + return selector{sel, false, optional, s, 0}, nil +} + +func MustParse(sel string) Selector { + s, err := Parse(sel) + if err != nil { + panic(err) + } + return s +} diff --git a/selector/selector_test.go b/selector/selector_test.go new file mode 100644 index 0000000..d8fab0b --- /dev/null +++ b/selector/selector_test.go @@ -0,0 +1,134 @@ +package selector + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + t.Run("identity", func(t *testing.T) { + sel, err := Parse(".") + require.NoError(t, err) + require.True(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Empty(t, sel.Field()) + require.Empty(t, sel.Index()) + }) + + t.Run("dotted field", func(t *testing.T) { + sel, err := Parse(".foo") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Equal(t, sel.Field(), "foo") + require.Empty(t, sel.Index()) + }) + + t.Run("dotted explicit field", func(t *testing.T) { + sel, err := Parse(".[\"foo\"]") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Equal(t, sel.Field(), "foo") + require.Empty(t, sel.Index()) + }) + + t.Run("dotted index", func(t *testing.T) { + sel, err := Parse(".[138]") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Empty(t, sel.Field()) + require.Equal(t, sel.Index(), 138) + }) + + t.Run("explicit field", func(t *testing.T) { + sel, err := Parse("[\"foo\"]") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Equal(t, sel.Field(), "foo") + require.Empty(t, sel.Index()) + }) + + t.Run("index", func(t *testing.T) { + sel, err := Parse("[138]") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Empty(t, sel.Field()) + require.Equal(t, sel.Index(), 138) + }) + + t.Run("negative index", func(t *testing.T) { + sel, err := Parse("[-138]") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.False(t, sel.Optional()) + require.Empty(t, sel.Field()) + require.Equal(t, sel.Index(), -138) + }) + + t.Run("optional dotted field", func(t *testing.T) { + sel, err := Parse(".foo?") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.True(t, sel.Optional()) + require.Equal(t, sel.Field(), "foo") + require.Empty(t, sel.Index()) + }) + + t.Run("optional dotted explicit field", func(t *testing.T) { + sel, err := Parse(".[\"foo\"]?") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.True(t, sel.Optional()) + require.Equal(t, sel.Field(), "foo") + require.Empty(t, sel.Index()) + }) + + t.Run("optional dotted index", func(t *testing.T) { + sel, err := Parse(".[138]?") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.True(t, sel.Optional()) + require.Empty(t, sel.Field()) + require.Equal(t, sel.Index(), 138) + }) + + t.Run("optional explicit field", func(t *testing.T) { + sel, err := Parse("[\"foo\"]?") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.True(t, sel.Optional()) + require.Equal(t, sel.Field(), "foo") + require.Empty(t, sel.Index()) + }) + + t.Run("optional index", func(t *testing.T) { + sel, err := Parse("[138]?") + require.NoError(t, err) + require.False(t, sel.Identity()) + require.True(t, sel.Optional()) + require.Empty(t, sel.Field()) + require.Equal(t, sel.Index(), 138) + }) + + t.Run("non dotted", func(t *testing.T) { + _, err := Parse("foo") + if err == nil { + t.Fatalf("expected error parsing selector") + } + fmt.Println(err) + }) + + t.Run("non quoted", func(t *testing.T) { + _, err := Parse(".[foo]") + if err == nil { + t.Fatalf("expected error parsing selector") + } + fmt.Println(err) + }) +} diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..cbedc9f --- /dev/null +++ b/statement.go @@ -0,0 +1,204 @@ +package policy + +// https://github.com/ucan-wg/delegation/blob/4094d5878b58f5d35055a3b93fccda0b8329ebae/README.md#policy + +import ( + "github.com/storacha-network/go-ucanto/core/policy/literal" + "github.com/storacha-network/go-ucanto/core/policy/selector" +) + +const ( + Kind_Equal = "==" + Kind_GreaterThan = ">" + Kind_GreaterThanOrEqual = ">=" + Kind_LessThan = "<" + Kind_LessThanOrEqual = "<=" + Kind_Negation = "not" + Kind_Conjunction = "and" + Kind_Disjunction = "or" + Kind_Wildcard = "like" + Kind_Universal = "all" + Kind_Existential = "any" +) + +type Policy = []Statement + +type Statement interface { + Kind() string +} + +type EqualityStatement interface { + Statement + Selector() selector.Selector + Value() literal.Literal +} + +type InequalityStatement interface { + Statement + Selector() selector.Selector + Value() literal.Literal +} + +type WildcardStatement interface { + Statement + Selector() selector.Selector + Value() string +} + +type ConnectiveStatement interface { + Statement +} + +type NegationStatement interface { + ConnectiveStatement + Value() Statement +} + +type ConjunctionStatement interface { + ConnectiveStatement + Value() []Statement +} + +type DisjunctionStatement interface { + ConnectiveStatement + Value() []Statement +} + +type QuantifierStatement interface { + Statement + Selector() selector.Selector + Value() Policy +} + +type equality struct { + kind string + selector selector.Selector + value literal.Literal +} + +func (e equality) Kind() string { + return e.kind +} + +func (e equality) Value() literal.Literal { + return e.value +} + +func (e equality) Selector() selector.Selector { + return e.selector +} + +func Equal(selector selector.Selector, value literal.Literal) EqualityStatement { + return equality{Kind_Equal, selector, value} +} + +func GreaterThan(selector selector.Selector, value literal.Literal) InequalityStatement { + return equality{Kind_GreaterThan, selector, value} +} + +func GreaterThanOrEqual(selector selector.Selector, value literal.Literal) InequalityStatement { + return equality{Kind_GreaterThanOrEqual, selector, value} +} + +func LessThan(selector selector.Selector, value literal.Literal) InequalityStatement { + return equality{Kind_LessThan, selector, value} +} + +func LessThanOrEqual(selector selector.Selector, value literal.Literal) InequalityStatement { + return equality{Kind_LessThanOrEqual, selector, value} +} + +type negation struct { + statement Statement +} + +func (n negation) Kind() string { + return Kind_Negation +} + +func (n negation) Value() Statement { + return n.statement +} + +func Not(stmt Statement) NegationStatement { + return negation{stmt} +} + +type conjunction struct { + statements []Statement +} + +func (n conjunction) Kind() string { + return Kind_Conjunction +} + +func (n conjunction) Value() []Statement { + return n.statements +} + +func And(stmts ...Statement) ConjunctionStatement { + return conjunction{stmts} +} + +type disjunction struct { + statements []Statement +} + +func (n disjunction) Kind() string { + return Kind_Disjunction +} + +func (n disjunction) Value() []Statement { + return n.statements +} + +func Or(stmts ...Statement) DisjunctionStatement { + return disjunction{stmts} +} + +type wildcard struct { + selector selector.Selector + pattern string +} + +func (n wildcard) Kind() string { + return Kind_Wildcard +} + +func (n wildcard) Selector() selector.Selector { + return n.selector +} + +func (n wildcard) Value() string { + return n.pattern +} + +func Like(selector selector.Selector, pattern string) WildcardStatement { + return wildcard{selector, pattern} +} + +type quantifier struct { + kind string + selector selector.Selector + policy Policy +} + +func (n quantifier) Kind() string { + return n.kind +} + +func (n quantifier) Selector() selector.Selector { + return n.selector +} + +func (n quantifier) Value() Policy { + return n.policy +} + +func All(selector selector.Selector, policy Policy) QuantifierStatement { + return quantifier{Kind_Universal, selector, policy} +} + +func Any(selector selector.Selector, policy Policy) QuantifierStatement { + return quantifier{Kind_Existential, selector, policy} +}