diff --git a/pkg/policy/match_test.go b/pkg/policy/match_test.go index 46a4a4f..df3459f 100644 --- a/pkg/policy/match_test.go +++ b/pkg/policy/match_test.go @@ -590,3 +590,80 @@ func TestOptionalSelectors(t *testing.T) { }) } } + +// The unique behaviour of PartialMatch is that it should return true for missing non-optional data (unlike Match). +func TestPartialMatch(t *testing.T) { + tests := []struct { + name string + policy Policy + data interface{} + expectedMatch bool + expectedStmt Statement + }{ + { + name: "returns true for missing non-optional field", + policy: MustConstruct( + Equal(".field", literal.String("value")), + ), + data: map[string]interface{}{}, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns true when present data matches", + policy: MustConstruct( + Equal(".foo", literal.String("correct")), + Equal(".missing", literal.String("whatever")), + ), + data: map[string]interface{}{ + "foo": "correct", + }, + expectedMatch: true, + expectedStmt: nil, + }, + { + name: "returns false with failing statement for present but non-matching value", + policy: MustConstruct( + Equal(".foo", literal.String("value1")), + Equal(".bar", literal.String("value2")), + ), + data: map[string]interface{}{ + "foo": "wrong", + "bar": "value2", + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Equal(".foo", literal.String("value1")), + )[0], + }, + { + name: "continues past missing data until finding actual mismatch", + policy: MustConstruct( + Equal(".missing", literal.String("value")), + Equal(".present", literal.String("wrong")), + ), + data: map[string]interface{}{ + "present": "actual", + }, + expectedMatch: false, + expectedStmt: MustConstruct( + Equal(".present", literal.String("wrong")), + )[0], + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node, err := literal.Map(tt.data) + assert.NoError(t, err) + + match, stmt := tt.policy.PartialMatch(node) + assert.Equal(t, tt.expectedMatch, match) + if tt.expectedStmt == nil { + assert.Nil(t, stmt) + } else { + assert.Equal(t, tt.expectedStmt.Kind(), stmt.Kind()) + } + }) + } +}