From 72f4ef7b5eb3d42ca51a86899e9b5f5b51f7541e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Mur=C3=A9?= Date: Mon, 4 Nov 2024 19:07:36 +0100 Subject: [PATCH] policy: fix incorrect test for PartialMatch --- pkg/policy/match.go | 4 +--- pkg/policy/match_test.go | 17 ++++++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/pkg/policy/match.go b/pkg/policy/match.go index 2a586f5..c3862d5 100644 --- a/pkg/policy/match.go +++ b/pkg/policy/match.go @@ -242,9 +242,7 @@ func matchStatement(cur Statement, node ipld.Node) (_ matchResult, leafMost Stat // continue } } - - // when no elements match, return the leaf statement instead of 'cur' - return matchResultFalse, s.statement + return matchResultFalse, cur } } panic(fmt.Errorf("unimplemented statement kind: %s", cur.Kind())) diff --git a/pkg/policy/match_test.go b/pkg/policy/match_test.go index 56a2814..7d10d43 100644 --- a/pkg/policy/match_test.go +++ b/pkg/policy/match_test.go @@ -9,7 +9,6 @@ import ( "github.com/ipld/go-ipld-prime/codec/dagjson" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ucan-wg/go-ucan/pkg/policy/literal" @@ -581,12 +580,12 @@ func TestOptionalSelectors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { nb := basicnode.Prototype.Map.NewBuilder() n, err := literal.Map(tt.data) - assert.NoError(t, err) + require.NoError(t, err) err = nb.AssignNode(n) - assert.NoError(t, err) + require.NoError(t, err) result := tt.policy.Match(nb.Build()) - assert.Equal(t, tt.expected, result) + require.Equal(t, tt.expected, result) }) } } @@ -738,7 +737,7 @@ func TestPartialMatch(t *testing.T) { }, expectedMatch: false, expectedStmt: MustConstruct( - Equal(".", literal.Int(4)), + Any(".numbers", Equal(".", literal.Int(4))), )[0], }, @@ -878,14 +877,14 @@ func TestPartialMatch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { node, err := literal.Map(tt.data) - assert.NoError(t, err) + require.NoError(t, err) match, stmt := tt.policy.PartialMatch(node) - assert.Equal(t, tt.expectedMatch, match) + require.Equal(t, tt.expectedMatch, match) if tt.expectedStmt == nil { - assert.Nil(t, stmt) + require.Nil(t, stmt) } else { - assert.Equal(t, tt.expectedStmt.Kind(), stmt.Kind()) + require.Equal(t, tt.expectedStmt, stmt) } }) }