From 51e8d5ce044680670bb6920c16517d9805b5aad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Mur=C3=A9?= Date: Mon, 14 Oct 2024 20:09:21 +0200 Subject: [PATCH] policy: fluent construction --- pkg/container/serial_test.go | 10 +-- pkg/policy/ipld.go | 6 +- pkg/policy/literal/literal.go | 22 ++--- pkg/policy/match_test.go | 133 +++++++++++---------------- pkg/policy/policy.go | 144 ++++++++++++++++++++++-------- pkg/policy/policy_test.go | 60 +++++++++++++ pkg/policy/selector/parsing.go | 5 ++ pkg/policy/selector/selector.go | 2 - token/delegation/examples_test.go | 33 ++++--- 9 files changed, 260 insertions(+), 155 deletions(-) create mode 100644 pkg/policy/policy_test.go diff --git a/pkg/container/serial_test.go b/pkg/container/serial_test.go index 6552f88..e2d3d94 100644 --- a/pkg/container/serial_test.go +++ b/pkg/container/serial_test.go @@ -17,7 +17,6 @@ import ( "github.com/ucan-wg/go-ucan/pkg/command" "github.com/ucan-wg/go-ucan/pkg/policy" "github.com/ucan-wg/go-ucan/pkg/policy/literal" - "github.com/ucan-wg/go-ucan/pkg/policy/selector" "github.com/ucan-wg/go-ucan/token/delegation" ) @@ -159,10 +158,11 @@ func randToken() (*delegation.Token, cid.Cid, []byte) { priv, iss := randDID() _, aud := randDID() cmd := command.New("foo", "bar") - pol := policy.Policy{policy.All( - selector.MustParse(".[]"), - policy.GreaterThan(selector.MustParse(".value"), literal.Int(2)), - )} + pol := policy.MustConstruct( + policy.All(".[]", + policy.GreaterThan(".value", literal.Int(2)), + ), + ) opts := []delegation.Option{ delegation.WithExpiration(time.Now().Add(time.Hour)), diff --git a/pkg/policy/ipld.go b/pkg/policy/ipld.go index e3c67d1..9d52d4d 100644 --- a/pkg/policy/ipld.go +++ b/pkg/policy/ipld.go @@ -61,7 +61,7 @@ func statementFromIPLD(path string, node datamodel.Node) (Statement, error) { if err != nil { return nil, err } - return Not(statement), nil + return negation{statement: statement}, nil case KindAnd, KindOr: arg2, _ := node.LookupByIndex(1) @@ -93,11 +93,11 @@ func statementFromIPLD(path string, node datamodel.Node) (Statement, error) { if pattern.Kind() != datamodel.Kind_String { return nil, ErrNotAString(combinePath(path, op, 2)) } - res, err := Like(sel, must.String(pattern)) + g, err := parseGlob(must.String(pattern)) if err != nil { return nil, ErrInvalidPattern(combinePath(path, op, 2), err) } - return res, nil + return wildcard{selector: sel, pattern: g}, nil case KindAll, KindAny: sel, err := arg2AsSelector(op) diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index d5fe54e..23e4e59 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -1,20 +1,12 @@ package literal import ( + "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" ) -func Node(n ipld.Node) ipld.Node { - return n -} - -func Link(cid ipld.Link) ipld.Node { - nb := basicnode.Prototype.Link.NewBuilder() - nb.AssignLink(cid) - return nb.Build() -} - func Bool(val bool) ipld.Node { nb := basicnode.Prototype.Bool.NewBuilder() nb.AssignBool(val) @@ -45,6 +37,16 @@ func Bytes(val []byte) ipld.Node { return nb.Build() } +func Link(link ipld.Link) ipld.Node { + nb := basicnode.Prototype.Link.NewBuilder() + nb.AssignLink(link) + return nb.Build() +} + +func LinkCid(cid cid.Cid) ipld.Node { + return Link(cidlink.Link{Cid: cid}) +} + func Null() ipld.Node { nb := basicnode.Prototype.Any.NewBuilder() nb.AssignNull() diff --git a/pkg/policy/match_test.go b/pkg/policy/match_test.go index 7daf72e..3705fd7 100644 --- a/pkg/policy/match_test.go +++ b/pkg/policy/match_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" "github.com/ucan-wg/go-ucan/pkg/policy/literal" - "github.com/ucan-wg/go-ucan/pkg/policy/selector" ) func TestMatch(t *testing.T) { @@ -23,15 +22,15 @@ func TestMatch(t *testing.T) { nb.AssignString("test") nd := nb.Build() - pol := Policy{Equal(selector.MustParse("."), literal.String("test"))} + pol := MustConstruct(Equal(".", literal.String("test"))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.String("test2"))} + pol = MustConstruct(Equal(".", literal.String("test2"))) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.Int(138))} + pol = MustConstruct(Equal(".", literal.Int(138))) ok = Match(pol, nd) require.False(t, ok) }) @@ -42,15 +41,15 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{Equal(selector.MustParse("."), literal.Int(138))} + pol := MustConstruct(Equal(".", literal.Int(138))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.Int(1138))} + pol = MustConstruct(Equal(".", literal.Int(1138))) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.String("138"))} + pol = MustConstruct(Equal(".", literal.String("138"))) ok = Match(pol, nd) require.False(t, ok) }) @@ -61,15 +60,15 @@ func TestMatch(t *testing.T) { nb.AssignFloat(1.138) nd := nb.Build() - pol := Policy{Equal(selector.MustParse("."), literal.Float(1.138))} + pol := MustConstruct(Equal(".", literal.Float(1.138))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.Float(11.38))} + pol = MustConstruct(Equal(".", literal.Float(11.38))) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.String("138"))} + pol = MustConstruct(Equal(".", literal.String("138"))) ok = Match(pol, nd) require.False(t, ok) }) @@ -83,15 +82,15 @@ func TestMatch(t *testing.T) { nb.AssignLink(l0) nd := nb.Build() - pol := Policy{Equal(selector.MustParse("."), literal.Link(l0))} + pol := MustConstruct(Equal(".", literal.Link(l0))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.Link(l1))} + pol = MustConstruct(Equal(".", literal.Link(l1))) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{Equal(selector.MustParse("."), literal.String("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq"))} + pol = MustConstruct(Equal(".", literal.String("bafybeif4owy5gno5lwnixqm52rwqfodklf76hsetxdhffuxnplvijskzqq"))) ok = Match(pol, nd) require.False(t, ok) }) @@ -105,19 +104,19 @@ func TestMatch(t *testing.T) { ma.Finish() nd := nb.Build() - pol := Policy{Equal(selector.MustParse(".foo"), literal.String("bar"))} + pol := MustConstruct(Equal(".foo", literal.String("bar"))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse(".[\"foo\"]"), literal.String("bar"))} + pol = MustConstruct(Equal(".[\"foo\"]", literal.String("bar"))) ok = Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse(".foo"), literal.String("baz"))} + pol = MustConstruct(Equal(".foo", literal.String("baz"))) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{Equal(selector.MustParse(".foobar"), literal.String("bar"))} + pol = MustConstruct(Equal(".foobar", literal.String("bar"))) ok = Match(pol, nd) require.False(t, ok) }) @@ -130,11 +129,11 @@ func TestMatch(t *testing.T) { la.Finish() nd := nb.Build() - pol := Policy{Equal(selector.MustParse(".[0]"), literal.String("foo"))} + pol := MustConstruct(Equal(".[0]", literal.String("foo"))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Equal(selector.MustParse(".[1]"), literal.String("foo"))} + pol = MustConstruct(Equal(".[1]", literal.String("foo"))) ok = Match(pol, nd) require.False(t, ok) }) @@ -147,7 +146,7 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{GreaterThan(selector.MustParse("."), literal.Int(1))} + pol := MustConstruct(GreaterThan(".", literal.Int(1))) ok := Match(pol, nd) require.True(t, ok) }) @@ -158,11 +157,11 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(1))} + pol := MustConstruct(GreaterThanOrEqual(".", literal.Int(1))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Int(138))} + pol = MustConstruct(GreaterThanOrEqual(".", literal.Int(138))) ok = Match(pol, nd) require.True(t, ok) }) @@ -173,7 +172,7 @@ func TestMatch(t *testing.T) { nb.AssignFloat(1.38) nd := nb.Build() - pol := Policy{GreaterThan(selector.MustParse("."), literal.Float(1))} + pol := MustConstruct(GreaterThan(".", literal.Float(1))) ok := Match(pol, nd) require.True(t, ok) }) @@ -184,11 +183,11 @@ func TestMatch(t *testing.T) { nb.AssignFloat(1.38) nd := nb.Build() - pol := Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1))} + pol := MustConstruct(GreaterThanOrEqual(".", literal.Float(1))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{GreaterThanOrEqual(selector.MustParse("."), literal.Float(1.38))} + pol = MustConstruct(GreaterThanOrEqual(".", literal.Float(1.38))) ok = Match(pol, nd) require.True(t, ok) }) @@ -199,7 +198,7 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{LessThan(selector.MustParse("."), literal.Int(1138))} + pol := MustConstruct(LessThan(".", literal.Int(1138))) ok := Match(pol, nd) require.True(t, ok) }) @@ -210,11 +209,11 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(1138))} + pol := MustConstruct(LessThanOrEqual(".", literal.Int(1138))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{LessThanOrEqual(selector.MustParse("."), literal.Int(138))} + pol = MustConstruct(LessThanOrEqual(".", literal.Int(138))) ok = Match(pol, nd) require.True(t, ok) }) @@ -226,11 +225,11 @@ func TestMatch(t *testing.T) { nb.AssignBool(false) nd := nb.Build() - pol := Policy{Not(Equal(selector.MustParse("."), literal.Bool(true)))} + pol := MustConstruct(Not(Equal(".", literal.Bool(true)))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{Not(Equal(selector.MustParse("."), literal.Bool(false)))} + pol = MustConstruct(Not(Equal(".", literal.Bool(false)))) ok = Match(pol, nd) require.False(t, ok) }) @@ -241,25 +240,25 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{ + pol := MustConstruct( And( - GreaterThan(selector.MustParse("."), literal.Int(1)), - LessThan(selector.MustParse("."), literal.Int(1138)), + GreaterThan(".", literal.Int(1)), + LessThan(".", literal.Int(1138)), ), - } + ) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{ + pol = MustConstruct( And( - GreaterThan(selector.MustParse("."), literal.Int(1)), - Equal(selector.MustParse("."), literal.Int(1138)), + GreaterThan(".", literal.Int(1)), + Equal(".", literal.Int(1138)), ), - } + ) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{And()} + pol = MustConstruct(And()) ok = Match(pol, nd) require.True(t, ok) }) @@ -270,25 +269,25 @@ func TestMatch(t *testing.T) { nb.AssignInt(138) nd := nb.Build() - pol := Policy{ + pol := MustConstruct( Or( - GreaterThan(selector.MustParse("."), literal.Int(138)), - LessThan(selector.MustParse("."), literal.Int(1138)), + GreaterThan(".", literal.Int(138)), + LessThan(".", literal.Int(1138)), ), - } + ) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{ + pol = MustConstruct( Or( - GreaterThan(selector.MustParse("."), literal.Int(138)), - Equal(selector.MustParse("."), literal.Int(1138)), + GreaterThan(".", literal.Int(138)), + Equal(".", literal.Int(1138)), ), - } + ) ok = Match(pol, nd) require.False(t, ok) - pol = Policy{Or()} + pol = MustConstruct(Or()) ok = Match(pol, nd) require.True(t, ok) }) @@ -309,10 +308,7 @@ func TestMatch(t *testing.T) { nb.AssignString(s) nd := nb.Build() - statement, err := Like(selector.MustParse("."), pattern) - require.NoError(t, err) - - pol := Policy{statement} + pol := MustConstruct(Like(".", pattern)) ok := Match(pol, nd) require.True(t, ok) }) @@ -333,10 +329,7 @@ func TestMatch(t *testing.T) { nb.AssignString(s) nd := nb.Build() - statement, err := Like(selector.MustParse("."), pattern) - require.NoError(t, err) - - pol := Policy{statement} + pol := MustConstruct(Like(".", pattern)) ok := Match(pol, nd) require.False(t, ok) }) @@ -367,21 +360,11 @@ func TestMatch(t *testing.T) { la.Finish() nd := nb.Build() - pol := Policy{ - All( - selector.MustParse(".[]"), - GreaterThan(selector.MustParse(".value"), literal.Int(2)), - ), - } + pol := MustConstruct(All(".[]", GreaterThan(".value", literal.Int(2)))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{ - All( - selector.MustParse(".[]"), - GreaterThan(selector.MustParse(".value"), literal.Int(20)), - ), - } + pol = MustConstruct(All(".[]", GreaterThan(".value", literal.Int(20)))) ok = Match(pol, nd) require.False(t, ok) }) @@ -398,21 +381,11 @@ func TestMatch(t *testing.T) { la.Finish() nd := nb.Build() - pol := Policy{ - Any( - selector.MustParse(".[]"), - GreaterThan(selector.MustParse(".value"), literal.Int(60)), - ), - } + pol := MustConstruct(Any(".[]", GreaterThan(".value", literal.Int(60)))) ok := Match(pol, nd) require.True(t, ok) - pol = Policy{ - Any( - selector.MustParse(".[]"), - GreaterThan(selector.MustParse(".value"), literal.Int(100)), - ), - } + pol = MustConstruct(Any(".[]", GreaterThan(".value", literal.Int(100)))) ok = Match(pol, nd) require.False(t, ok) }) diff --git a/pkg/policy/policy.go b/pkg/policy/policy.go index e7807dc..5da2ea9 100644 --- a/pkg/policy/policy.go +++ b/pkg/policy/policy.go @@ -9,7 +9,7 @@ import ( "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/codec/dagjson" - "github.com/ucan-wg/go-ucan/pkg/policy/selector" + selpkg "github.com/ucan-wg/go-ucan/pkg/policy/selector" ) const ( @@ -28,6 +28,24 @@ const ( type Policy []Statement +type Constructor func() (Statement, error) + +func Construct(cstors ...Constructor) (Policy, error) { + stmts, err := assemble(cstors) + if err != nil { + return nil, err + } + return stmts, nil +} + +func MustConstruct(cstors ...Constructor) Policy { + pol, err := Construct(cstors...) + if err != nil { + panic(err) + } + return pol +} + func (p Policy) String() string { if len(p) == 0 { return "[]" @@ -46,7 +64,7 @@ type Statement interface { type equality struct { kind string - selector selector.Selector + selector selpkg.Selector value ipld.Node } @@ -62,24 +80,39 @@ func (e equality) String() string { return fmt.Sprintf(`["%s", "%s", %s]`, e.kind, e.selector, strings.ReplaceAll(string(child), "\n", "\n ")) } -func Equal(selector selector.Selector, value ipld.Node) Statement { - return equality{kind: KindEqual, selector: selector, value: value} +func Equal(selector string, value ipld.Node) Constructor { + return func() (Statement, error) { + sel, err := selpkg.Parse(selector) + return equality{kind: KindEqual, selector: sel, value: value}, err + } } -func GreaterThan(selector selector.Selector, value ipld.Node) Statement { - return equality{kind: KindGreaterThan, selector: selector, value: value} +func GreaterThan(selector string, value ipld.Node) Constructor { + return func() (Statement, error) { + sel, err := selpkg.Parse(selector) + return equality{kind: KindGreaterThan, selector: sel, value: value}, err + } } -func GreaterThanOrEqual(selector selector.Selector, value ipld.Node) Statement { - return equality{kind: KindGreaterThanOrEqual, selector: selector, value: value} +func GreaterThanOrEqual(selector string, value ipld.Node) Constructor { + return func() (Statement, error) { + sel, err := selpkg.Parse(selector) + return equality{kind: KindGreaterThanOrEqual, selector: sel, value: value}, err + } } -func LessThan(selector selector.Selector, value ipld.Node) Statement { - return equality{kind: KindLessThan, selector: selector, value: value} +func LessThan(selector string, value ipld.Node) Constructor { + return func() (Statement, error) { + sel, err := selpkg.Parse(selector) + return equality{kind: KindLessThan, selector: sel, value: value}, err + } } -func LessThanOrEqual(selector selector.Selector, value ipld.Node) Statement { - return equality{kind: KindLessThanOrEqual, selector: selector, value: value} +func LessThanOrEqual(selector string, value ipld.Node) Constructor { + return func() (Statement, error) { + sel, err := selpkg.Parse(selector) + return equality{kind: KindLessThanOrEqual, selector: sel, value: value}, err + } } type negation struct { @@ -95,8 +128,11 @@ func (n negation) String() string { return fmt.Sprintf(`["%s", "%s"]`, n.Kind(), strings.ReplaceAll(child, "\n", "\n ")) } -func Not(stmt Statement) Statement { - return negation{statement: stmt} +func Not(cstor Constructor) Constructor { + return func() (Statement, error) { + stmt, err := cstor() + return negation{statement: stmt}, err + } } type connective struct { @@ -116,16 +152,28 @@ func (c connective) String() string { return fmt.Sprintf("[\"%s\", [\n %s]]\n", c.kind, strings.Join(childs, ",\n ")) } -func And(stmts ...Statement) Statement { - return connective{kind: KindAnd, statements: stmts} +func And(cstors ...Constructor) Constructor { + return func() (Statement, error) { + stmts, err := assemble(cstors) + if err != nil { + return nil, err + } + return connective{kind: KindAnd, statements: stmts}, nil + } } -func Or(stmts ...Statement) Statement { - return connective{kind: KindOr, statements: stmts} +func Or(cstors ...Constructor) Constructor { + return func() (Statement, error) { + stmts, err := assemble(cstors) + if err != nil { + return nil, err + } + return connective{kind: KindOr, statements: stmts}, nil + } } type wildcard struct { - selector selector.Selector + selector selpkg.Selector pattern glob } @@ -137,26 +185,20 @@ func (n wildcard) String() string { return fmt.Sprintf(`["%s", "%s", "%s"]`, n.Kind(), n.selector, n.pattern) } -func Like(selector selector.Selector, pattern string) (Statement, error) { - g, err := parseGlob(pattern) - if err != nil { - return nil, err +func Like(selector string, pattern string) Constructor { + return func() (Statement, error) { + g, err := parseGlob(pattern) + if err != nil { + return nil, err + } + sel, err := selpkg.Parse(selector) + return wildcard{selector: sel, pattern: g}, err } - - return wildcard{selector: selector, pattern: g}, nil -} - -func MustLike(selector selector.Selector, pattern string) Statement { - g, err := Like(selector, pattern) - if err != nil { - panic(err) - } - return g } type quantifier struct { kind string - selector selector.Selector + selector selpkg.Selector statement Statement } @@ -169,10 +211,36 @@ func (n quantifier) String() string { return fmt.Sprintf("[\"%s\", \"%s\",\n %s]", n.Kind(), n.selector, strings.ReplaceAll(child, "\n", "\n ")) } -func All(selector selector.Selector, statement Statement) Statement { - return quantifier{kind: KindAll, selector: selector, statement: statement} +func All(selector string, cstor Constructor) Constructor { + return func() (Statement, error) { + stmt, err := cstor() + if err != nil { + return nil, err + } + sel, err := selpkg.Parse(selector) + return quantifier{kind: KindAll, selector: sel, statement: stmt}, err + } } -func Any(selector selector.Selector, statement Statement) Statement { - return quantifier{kind: KindAny, selector: selector, statement: statement} +func Any(selector string, cstor Constructor) Constructor { + return func() (Statement, error) { + stmt, err := cstor() + if err != nil { + return nil, err + } + sel, err := selpkg.Parse(selector) + return quantifier{kind: KindAny, selector: sel, statement: stmt}, err + } +} + +func assemble(cstors []Constructor) ([]Statement, error) { + stmts := make([]Statement, 0, len(cstors)) + for _, cstor := range cstors { + stmt, err := cstor() + if err != nil { + return nil, err + } + stmts = append(stmts, stmt) + } + return stmts, nil } diff --git a/pkg/policy/policy_test.go b/pkg/policy/policy_test.go new file mode 100644 index 0000000..26ed239 --- /dev/null +++ b/pkg/policy/policy_test.go @@ -0,0 +1,60 @@ +package policy_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ucan-wg/go-ucan/pkg/policy" + "github.com/ucan-wg/go-ucan/pkg/policy/literal" +) + +func ExamplePolicy() { + pol := policy.MustConstruct( + policy.Equal(".status", literal.String("draft")), + policy.All(".reviewer", + policy.Like(".email", "*@example.com"), + ), + policy.Any(".tags", policy.Or( + policy.Equal(".", literal.String("news")), + policy.Equal(".", literal.String("press")), + )), + ) + + fmt.Println(pol) + + // Output: + // [ + // ["==", ".status", "draft"], + // ["all", ".reviewer", + // ["like", ".email", "*@example.com"]], + // ["any", ".tags", + // ["or", [ + // ["==", ".", "news"], + // ["==", ".", "press"]]] + // ] + // ] +} + +func TestConstruct(t *testing.T) { + pol, err := policy.Construct( + policy.Equal(".status", literal.String("draft")), + policy.All(".reviewer", + policy.Like(".email", "*@example.com"), + ), + ) + require.NoError(t, err) + require.NotNil(t, pol) + + // check if errors cascade correctly + pol, err = policy.Construct( + policy.Equal(".status", literal.String("draft")), + policy.All(".reviewer", policy.Or( + policy.Like(".email", "*@example.com"), + policy.Like(".", "\\"), // invalid pattern + )), + ) + require.Error(t, err) + require.Nil(t, pol) +} diff --git a/pkg/policy/selector/parsing.go b/pkg/policy/selector/parsing.go index e47a70a..887be6c 100644 --- a/pkg/policy/selector/parsing.go +++ b/pkg/policy/selector/parsing.go @@ -6,6 +6,8 @@ import ( "strings" ) +var identity = Selector{segment{".", true, false, false, nil, "", 0}} + func Parse(str string) (Selector, error) { if len(str) == 0 { return nil, newParseError("empty selector", str, 0, "") @@ -13,6 +15,9 @@ func Parse(str string) (Selector, error) { if string(str[0]) != "." { return nil, newParseError("selector must start with identity segment '.'", str, 0, string(str[0])) } + if str == "." { + return identity, nil + } col := 0 var sel Selector diff --git a/pkg/policy/selector/selector.go b/pkg/policy/selector/selector.go index e116c08..e4a1401 100644 --- a/pkg/policy/selector/selector.go +++ b/pkg/policy/selector/selector.go @@ -23,8 +23,6 @@ func (s Selector) String() string { return res.String() } -var Identity = MustParse(".") - var ( indexRegex = regexp.MustCompile(`^-?\d+$`) sliceRegex = regexp.MustCompile(`^((\-?\d+:\-?\d*)|(\-?\d*:\-?\d+))$`) diff --git a/token/delegation/examples_test.go b/token/delegation/examples_test.go index 89d1ee5..3766939 100644 --- a/token/delegation/examples_test.go +++ b/token/delegation/examples_test.go @@ -19,7 +19,6 @@ import ( "github.com/ucan-wg/go-ucan/pkg/command" "github.com/ucan-wg/go-ucan/pkg/policy" "github.com/ucan-wg/go-ucan/pkg/policy/literal" - "github.com/ucan-wg/go-ucan/pkg/policy/selector" "github.com/ucan-wg/go-ucan/token/delegation" "github.com/ucan-wg/go-ucan/token/internal/envelope" ) @@ -41,16 +40,16 @@ func ExampleNew() { cmd := command.MustParse("/foo/bar") // The policy defines what is allowed to do. - pol := policy.Policy{ - policy.Equal(selector.MustParse(".status"), literal.String("draft")), - policy.All(selector.MustParse(".reviewer"), - policy.MustLike(selector.MustParse(".email"), "*@example.com"), + pol := policy.MustConstruct( + policy.Equal(".status", literal.String("draft")), + policy.All(".reviewer", + policy.Like(".email", "*@example.com"), ), - policy.Any(selector.MustParse(".tags"), policy.Or( - policy.Equal(selector.Identity, literal.String("news")), - policy.Equal(selector.Identity, literal.String("press")), + policy.Any(".tags", policy.Or( + policy.Equal(".", literal.String("news")), + policy.Equal(".", literal.String("press")), )), - } + ) tkn, err := delegation.New(issPriv, audDid, cmd, pol, delegation.WithSubject(subDid), @@ -161,16 +160,16 @@ func ExampleRoot() { cmd := command.MustParse("/foo/bar") // The policy defines what is allowed to do. - pol := policy.Policy{ - policy.Equal(selector.MustParse(".status"), literal.String("draft")), - policy.All(selector.MustParse(".reviewer"), - policy.MustLike(selector.MustParse(".email"), "*@example.com"), + pol := policy.MustConstruct( + policy.Equal(".status", literal.String("draft")), + policy.All(".reviewer", + policy.Like(".email", "*@example.com"), ), - policy.Any(selector.MustParse(".tags"), policy.Or( - policy.Equal(selector.Identity, literal.String("news")), - policy.Equal(selector.Identity, literal.String("press")), + policy.Any(".tags", policy.Or( + policy.Equal(".", literal.String("news")), + policy.Equal(".", literal.String("press")), )), - } + ) tkn, err := delegation.Root(issPriv, audDid, cmd, pol, delegation.WithExpirationIn(time.Hour),