Files
common/webauthn/authenticator_test.go

624 lines
15 KiB
Go
Raw Permalink Normal View History

2025-10-10 10:17:22 -04:00
package webauthn
import (
"encoding/base64"
"encoding/binary"
"encoding/hex"
"fmt"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Authenticator data fixtures, standard-base64 encoded (the tests decode them
// with base64.StdEncoding). Decoded, each follows the WebAuthn authenticator
// data layout the tests below index into: RP ID hash (bytes 0-31), flags
// (byte 32; 0x41 = user present | attested credential data in both), signature
// counter (33-36), AAGUID (37-52), big-endian credential ID length (53-54;
// 64 in both), the credential ID, then a COSE-encoded credential public key.
// NOTE(review): the "none"/"att" prefixes presumably name the attestation
// format each sample was captured with — confirm against their origin.
const (
noneAuthDataBase64 = "pkLSG3xtVeHOI8U5mCjSx0m/am7y/gPMnhDN9O1ttItBAAAAAAAAAAAAAAAAAAAAAAAAAAAAQMAxl6G32ykWaLrv/ouCs5HoGsvONqBtOb7ZmyMs8K8PccnwyyqPzWn/yZuyQmQBguvjYSvH6gDBlFG65quUDCSlAQIDJiABIVggyJGP+ra/u/eVjqN4OeYXUShRWxrEeC6Sb5/bZmJ9q8MiWCCHIkRdg5oRb1RHoFVYUpogcjlObCKFsV1ls1T+uUc6rA=="
attAuthDataBase64 = "lWkIjx7O4yMpVANdvRDXyuORMFonUbVZu4/Xy7IpvdRBAAAAAAAAAAAAAAAAAAAAAAAAAAAAQIniszxcGnhupdPFOHJIm6dscrWCC2h8xHicBMu91THD0kdOdB0QQtkaEn+6KfsfT1o3NmmFT8YfXrG734WfVSmlAQIDJiABIVggyoHHeiUw5aSbt8/GsL9zaqZGRzV26A4y3CnCGUhVXu4iWCBMnc8za5xgPzIygngAv9W+vZTMGJwwZcM4sjiqkcb/1g=="
)
// TestAuthenticatorFlags_UserPresent checks that UserPresent reports the 0x01
// bit of the flags byte and is not fooled by any other bit.
func TestAuthenticatorFlags_UserPresent(t *testing.T) {
	var (
		goodByte byte = 0x01
		badByte  byte = 0x10
	)
	tests := []struct {
		name string
		flag AuthenticatorFlags
		want bool
	}{
		{
			"Present",
			AuthenticatorFlags(goodByte),
			true,
		},
		{
			"Missing",
			AuthenticatorFlags(badByte),
			false,
		},
		// Pin the mask: the bit must be found alongside every other flag,
		// and no combination of the other bits may stand in for it.
		{
			"Present among all flags",
			AuthenticatorFlags(0xFF),
			true,
		},
		{
			"Missing with all other flags",
			AuthenticatorFlags(^goodByte),
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.flag.UserPresent(); got != tt.want {
				t.Errorf("AuthenticatorFlags.UserPresent() = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestAuthenticatorFlags_UserVerified checks that UserVerified reports the
// 0x04 bit of the flags byte and is not fooled by any other bit.
func TestAuthenticatorFlags_UserVerified(t *testing.T) {
	var (
		goodByte byte = 0x04
		badByte  byte = 0x02
	)
	tests := []struct {
		name string
		flag AuthenticatorFlags
		want bool
	}{
		{
			"Present",
			AuthenticatorFlags(goodByte),
			true,
		},
		{
			"Missing",
			AuthenticatorFlags(badByte),
			false,
		},
		// Pin the mask: the bit must be found alongside every other flag,
		// and no combination of the other bits may stand in for it.
		{
			"Present among all flags",
			AuthenticatorFlags(0xFF),
			true,
		},
		{
			"Missing with all other flags",
			AuthenticatorFlags(^goodByte),
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.flag.UserVerified(); got != tt.want {
				t.Errorf("AuthenticatorFlags.UserVerified() = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestAuthenticatorFlags_HasAttestedCredentialData checks that
// HasAttestedCredentialData reports the 0x40 bit of the flags byte and is not
// fooled by any other bit.
func TestAuthenticatorFlags_HasAttestedCredentialData(t *testing.T) {
	var (
		goodByte byte = 0x40
		badByte  byte = 0x01
	)
	tests := []struct {
		name string
		flag AuthenticatorFlags
		want bool
	}{
		{
			"Present",
			AuthenticatorFlags(goodByte),
			true,
		},
		{
			"Missing",
			AuthenticatorFlags(badByte),
			false,
		},
		// Pin the mask: the bit must be found alongside every other flag,
		// and no combination of the other bits may stand in for it.
		{
			"Present among all flags",
			AuthenticatorFlags(0xFF),
			true,
		},
		{
			"Missing with all other flags",
			AuthenticatorFlags(^goodByte),
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.flag.HasAttestedCredentialData(); got != tt.want {
				t.Errorf(
					"AuthenticatorFlags.HasAttestedCredentialData() = %v, want %v",
					got,
					tt.want,
				)
			}
		})
	}
}
// TestAuthenticatorFlags_HasExtensions checks that HasExtensions reports the
// 0x80 bit of the flags byte and is not fooled by any other bit.
func TestAuthenticatorFlags_HasExtensions(t *testing.T) {
	var (
		goodByte byte = 0x80
		badByte  byte = 0x01
	)
	tests := []struct {
		name string
		flag AuthenticatorFlags
		want bool
	}{
		{
			"Present",
			AuthenticatorFlags(goodByte),
			true,
		},
		{
			"Missing",
			AuthenticatorFlags(badByte),
			false,
		},
		// Pin the mask: the bit must be found alongside every other flag,
		// and no combination of the other bits may stand in for it.
		{
			"Present among all flags",
			AuthenticatorFlags(0xFF),
			true,
		},
		{
			"Missing with all other flags",
			AuthenticatorFlags(^goodByte),
			false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := tt.flag.HasExtensions(); got != tt.want {
				t.Errorf("AuthenticatorFlags.HasExtensions() = %v, want %v", got, tt.want)
			}
		})
	}
}
// TestAuthenticatorData_Unmarshal runs AuthenticatorData.Unmarshal over the two
// decoded fixtures (expected to succeed) and over corrupted copies of the
// attested fixture. For each failure it checks the error string and the
// protocol error fields (type, details, info).
func TestAuthenticatorData_Unmarshal(t *testing.T) {
	type fields struct {
		RPIDHash []byte
		Flags    AuthenticatorFlags
		Counter  uint32
		AttData  AttestedCredentialData
		ExtData  []byte
	}
	type args struct {
		rawAuthData []byte
	}
	// Check decode errors so a corrupted fixture fails the test outright
	// instead of silently feeding empty input into the "success" cases.
	noneAuthData, err := base64.StdEncoding.DecodeString(noneAuthDataBase64)
	require.NoError(t, err)
	attAuthData, err := base64.StdEncoding.DecodeString(attAuthDataBase64)
	require.NoError(t, err)

	// Empty data.
	badAuthData1 := []byte{}
	// Attested credential data missing: a prefix one byte shorter than
	// minAttestedAuthLength.
	badAuthData2 := make([]byte, minAttestedAuthLength-1)
	copy(badAuthData2, attAuthData)
	// Flags not set but data exists: clear bits 0x40 (attested credential
	// data) and 0x80 (extensions) of the flags byte at offset 32.
	badAuthData3 := make([]byte, len(attAuthData))
	copy(badAuthData3, attAuthData)
	badAuthData3[32] &= 0b0011_1111
	// Extensions data missing: set the 0x80 extensions bit without appending
	// any extension data.
	badAuthData4 := make([]byte, len(attAuthData))
	copy(badAuthData4, attAuthData)
	badAuthData4[32] |= 0b1000_0000
	// Leftover bytes: trailing garbage after an otherwise valid payload.
	badAuthData5 := make([]byte, len(attAuthData))
	copy(badAuthData5, attAuthData)
	badAuthData5 = append(badAuthData5, []byte("Hello World")...)

	tests := []struct {
		name       string
		fields     fields
		args       args
		errString  string
		errType    string
		errDetails string
		errInfo    string
	}{
		{
			name:   "None Unmarshal Successfully",
			fields: fields{},
			args: args{
				noneAuthData,
			},
		},
		{
			name:   "Att Data Unmarshal Successfully",
			fields: fields{},
			args: args{
				attAuthData,
			},
		},
		{
			name:   "Authenticator data too short",
			fields: fields{},
			args: args{
				badAuthData1,
			},
			errString:  "Authenticator data length too short",
			errType:    "invalid_request",
			errDetails: "Authenticator data length too short",
			errInfo: fmt.Sprintf(
				"Expected data greater than %d bytes. Got %d bytes",
				minAuthDataLength,
				len(badAuthData1),
			),
		},
		{
			name:   "Attested credential missing",
			fields: fields{},
			args: args{
				badAuthData2,
			},
			errString:  "Attested credential flag set but data is missing",
			errType:    "invalid_request",
			errDetails: "Attested credential flag set but data is missing",
			errInfo:    "",
		},
		{
			// Distinct name: this case previously reused "Attested credential
			// missing", making the two subtests ambiguous in test output.
			name:   "Attested credential flag not set",
			fields: fields{},
			args: args{
				badAuthData3,
			},
			errString:  "Attested credential flag not set",
			errType:    "invalid_request",
			errDetails: "Attested credential flag not set",
			errInfo:    "",
		},
		{
			name:   "Extensions data missing",
			fields: fields{},
			args: args{
				badAuthData4,
			},
			errString:  "Extensions flag set but extensions data is missing",
			errType:    "invalid_request",
			errDetails: "Extensions flag set but extensions data is missing",
			errInfo:    "",
		},
		{
			name:   "Leftover bytes",
			fields: fields{},
			args: args{
				badAuthData5,
			},
			errString:  "Leftover bytes decoding AuthenticatorData",
			errType:    "invalid_request",
			errDetails: "Leftover bytes decoding AuthenticatorData",
			errInfo:    "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &AuthenticatorData{
				RPIDHash: tt.fields.RPIDHash,
				Flags:    tt.fields.Flags,
				Counter:  tt.fields.Counter,
				AttData:  tt.fields.AttData,
				ExtData:  tt.fields.ExtData,
			}
			err := a.Unmarshal(tt.args.rawAuthData)
			if tt.errString != "" {
				assert.EqualError(t, err, tt.errString)
				AssertIsProtocolError(t, err, tt.errType, tt.errDetails, tt.errInfo)
				return
			}
			require.NoError(t, err)
		})
	}
}
// TestAuthenticatorData_unmarshalAttestedData runs unmarshalAttestedData over
// the decoded fixtures (expected to succeed) and over copies of the attested
// fixture whose credential ID length or credential public key is corrupted.
func TestAuthenticatorData_unmarshalAttestedData(t *testing.T) {
	type fields struct {
		RPIDHash []byte
		Flags    AuthenticatorFlags
		Counter  uint32
		AttData  AttestedCredentialData
		ExtData  []byte
	}
	type args struct {
		rawAuthData []byte
	}
	// Offset of the 2-byte big-endian credential ID length in authenticator
	// data: 32-byte RP ID hash + 1 flags byte + 4-byte counter + 16-byte AAGUID.
	const credIDLenOffset = 53

	// Check decode errors so a corrupted fixture fails the test outright
	// instead of silently feeding empty input into the cases below.
	noneAuthData, err := base64.StdEncoding.DecodeString(noneAuthDataBase64)
	require.NoError(t, err)
	attAuthData, err := base64.StdEncoding.DecodeString(attAuthDataBase64)
	require.NoError(t, err)

	// Data length too short: declare a 256-byte credential ID, more than the
	// bytes that actually follow the length field.
	badAuthData1 := make([]byte, len(attAuthData))
	copy(badAuthData1, attAuthData)
	binary.BigEndian.PutUint16(badAuthData1[credIDLenOffset:], 256)
	// ID length too long: pad the buffer so the declared length fits, leaving
	// only the maxCredentialIDLength cap to reject it.
	badAuthData2 := make([]byte, len(attAuthData)+maxCredentialIDLength+1)
	copy(badAuthData2, attAuthData)
	binary.BigEndian.PutUint16(badAuthData2[credIDLenOffset:], maxCredentialIDLength+1)
	// Malformed public key: keep everything through the end of the credential
	// ID, then append CBOR that opens a 3-element array (0x83) and hits a
	// "break" stop code (0xFF) where the first element should be.
	pubKeyOffset := credIDLenOffset + 2 +
		int(binary.BigEndian.Uint16(attAuthData[credIDLenOffset:]))
	badData, err := hex.DecodeString("83FF20030102")
	require.NoError(t, err)
	badAuthData3 := make([]byte, pubKeyOffset)
	copy(badAuthData3, attAuthData[:pubKeyOffset])
	badAuthData3 = append(badAuthData3, badData...)

	tests := []struct {
		name       string
		fields     fields
		args       args
		errString  string
		errType    string
		errDetails string
		errInfo    string
	}{
		{
			name:   "None Unmarshal Successfully",
			fields: fields{},
			args: args{
				noneAuthData,
			},
		},
		{
			name:   "Att Data Unmarshal Successfully",
			fields: fields{},
			args: args{
				attAuthData,
			},
		},
		{
			name:   "Data length too short",
			fields: fields{},
			args: args{
				badAuthData1,
			},
			errString:  "Authenticator attestation data length too short",
			errType:    "invalid_request",
			errDetails: "Authenticator attestation data length too short",
			errInfo:    "",
		},
		{
			name:   "ID length too long",
			fields: fields{},
			args: args{
				badAuthData2,
			},
			errString:  "Authenticator attestation data credential id length too long",
			errType:    "invalid_request",
			errDetails: "Authenticator attestation data credential id length too long",
			errInfo:    "",
		},
		{
			name:   "Could not unmarshal Credential Public Key",
			fields: fields{},
			args: args{
				badAuthData3,
			},
			errString:  "Could not unmarshal Credential Public Key: cbor: unexpected \"break\" code",
			errType:    "invalid_request",
			errDetails: "Could not unmarshal Credential Public Key: cbor: unexpected \"break\" code",
			errInfo:    "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &AuthenticatorData{
				RPIDHash: tt.fields.RPIDHash,
				Flags:    tt.fields.Flags,
				Counter:  tt.fields.Counter,
				AttData:  tt.fields.AttData,
				ExtData:  tt.fields.ExtData,
			}
			err := a.unmarshalAttestedData(tt.args.rawAuthData)
			if tt.errString != "" {
				assert.EqualError(t, err, tt.errString)
				AssertIsProtocolError(t, err, tt.errType, tt.errDetails, tt.errInfo)
				return
			}
			require.NoError(t, err)
		})
	}
}
// Test_unmarshalCredentialPublicKey checks that well-formed COSE keys (EC2,
// RSA, OKP) come back from unmarshalCredentialPublicKey byte-for-byte, and
// that malformed CBOR is rejected with an error.
func Test_unmarshalCredentialPublicKey(t *testing.T) {
	type args struct {
		keyBytes []byte
	}

	// EC2 P-256 (ES256) key.
	ec2Key := []byte{
		0xa5,       // map(5)
		0x01, 0x02, // kty: EC2 (2)
		0x03, 0x26, // alg: ES256 (-7)
		0x20, 0x01, // crv: P-256 (1)
		0x21, 0x58, 0x20, // x coordinate (32 bytes)
		0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
		0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
		0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
		0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
		0x22, 0x58, 0x20, // y coordinate (32 bytes)
		0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
		0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
		0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38,
		0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40,
	}

	// RSA (RS256) key with an all-zero 256-byte modulus and exponent 65537.
	rsaKey := append([]byte{
		0xa4,                   // map(4)
		0x01, 0x03,             // kty: RSA (3)
		0x03, 0x39, 0x01, 0x00, // alg: RS256 (-257)
		0x20, 0x59, 0x01, 0x00, // n: modulus (256 bytes)
	}, make([]byte, 256)...)
	rsaKey = append(rsaKey,
		0x21, 0x43, // e: exponent (3 bytes)
		0x01, 0x00, 0x01, // 65537
	)

	// OKP Ed25519 (EdDSA) key.
	ed25519Key := []byte{
		0xa4,       // map(4)
		0x01, 0x01, // kty: OKP (1)
		0x03, 0x27, // alg: EdDSA (-8)
		0x20, 0x06, // crv: Ed25519 (6)
		0x21, 0x58, 0x20, // x coordinate (32 bytes)
		0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
		0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
		0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
		0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
	}

	// clone returns an independent copy, so an expected value cannot be
	// affected by anything the function under test does to its input slice.
	clone := func(b []byte) []byte { return append([]byte(nil), b...) }

	tests := []struct {
		name    string
		args    args
		want    []byte
		wantErr bool
	}{
		{
			name: "Valid EC2 P-256 public key",
			args: args{keyBytes: ec2Key},
			want: clone(ec2Key),
		},
		{
			name: "Valid RSA public key",
			args: args{keyBytes: rsaKey},
			want: clone(rsaKey),
		},
		{
			name: "Valid Ed25519 public key",
			args: args{keyBytes: ed25519Key},
			want: clone(ed25519Key),
		},
		{
			// array(3) (0x83) immediately followed by a CBOR "break" stop
			// code (0xFF) outside any indefinite-length item.
			name:    "Malformed CBOR",
			args:    args{keyBytes: []byte{0x83, 0xff, 0x20, 0x03, 0x01, 0x02}},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := unmarshalCredentialPublicKey(tt.args.keyBytes)
			if tt.wantErr {
				if err == nil {
					t.Errorf("unmarshalCredentialPublicKey() = %x, want error", got)
				}
				return
			}
			if err != nil {
				t.Fatalf("unmarshalCredentialPublicKey() returned err %v", err)
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("unmarshalCredentialPublicKey() = %x, want %x", got, tt.want)
			}
		})
	}
}
// TestAuthenticatorData_Verify checks AuthenticatorData.Verify's RP ID hash
// comparison and its user-presence / user-verification requirements, including
// the passing case where both flags are required and set. The second Verify
// argument is nil in every case here.
func TestAuthenticatorData_Verify(t *testing.T) {
	type fields struct {
		RPIDHash []byte
		Flags    AuthenticatorFlags
		Counter  uint32
		AttData  AttestedCredentialData
		ExtData  []byte
	}
	type args struct {
		rpIDHash                 []byte
		userVerificationRequired bool
		userPresenceRequired     bool
	}
	tests := []struct {
		name       string
		fields     fields
		args       args
		errString  string
		errType    string
		errDetails string
		errInfo    string
	}{
		{
			name: "Success",
			fields: fields{
				RPIDHash: []byte{1, 2, 3},
				Flags:    AuthenticatorFlags(0x05),
			},
			args: args{
				rpIDHash: []byte{1, 2, 3},
			},
		},
		{
			// 0x05 sets both UP (0x01) and UV (0x04), so both requirements
			// are satisfied and Verify must succeed.
			name: "Success with presence and verification required",
			fields: fields{
				RPIDHash: []byte{1, 2, 3},
				Flags:    AuthenticatorFlags(0x05),
			},
			args: args{
				rpIDHash:                 []byte{1, 2, 3},
				userVerificationRequired: true,
				userPresenceRequired:     true,
			},
		},
		{
			name: "RP hash mismatch",
			fields: fields{
				RPIDHash: []byte{0xff},
			},
			args: args{
				rpIDHash: []byte{0xaa},
			},
			errString:  "Error validating the authenticator response",
			errType:    "verification_error",
			errDetails: "Error validating the authenticator response",
			errInfo:    "RP Hash mismatch. Expected ff and Received aa",
		},
		{
			name: "UP flag not set",
			fields: fields{
				RPIDHash: []byte{1, 2, 3},
				Flags:    AuthenticatorFlags(0x04),
			},
			args: args{
				rpIDHash:             []byte{1, 2, 3},
				userPresenceRequired: true,
			},
			errString:  "Error validating the authenticator response",
			errType:    "verification_error",
			errDetails: "Error validating the authenticator response",
			errInfo:    "User presence required but flag not set by authenticator",
		},
		{
			name: "User verification required",
			fields: fields{
				RPIDHash: []byte{1, 2, 3},
				Flags:    AuthenticatorFlags(0x01),
			},
			args: args{
				rpIDHash:                 []byte{1, 2, 3},
				userVerificationRequired: true,
				userPresenceRequired:     true,
			},
			errString:  "Error validating the authenticator response",
			errType:    "verification_error",
			errDetails: "Error validating the authenticator response",
			errInfo:    "User verification required but flag not set by authenticator",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &AuthenticatorData{
				RPIDHash: tt.fields.RPIDHash,
				Flags:    tt.fields.Flags,
				Counter:  tt.fields.Counter,
				AttData:  tt.fields.AttData,
				ExtData:  tt.fields.ExtData,
			}
			err := a.Verify(
				tt.args.rpIDHash,
				nil,
				tt.args.userVerificationRequired,
				tt.args.userPresenceRequired,
			)
			if tt.errString != "" {
				assert.EqualError(t, err, tt.errString)
				AssertIsProtocolError(t, err, tt.errType, tt.errDetails, tt.errInfo)
				return
			}
			require.NoError(t, err)
		})
	}
}