diff --git a/pkg/meta/internal/crypto/aes.go b/pkg/meta/internal/crypto/aes.go index 28f0ae4..482402e 100644 --- a/pkg/meta/internal/crypto/aes.go +++ b/pkg/meta/internal/crypto/aes.go @@ -31,20 +31,22 @@ func (ks KeySize) IsValid() bool { var ErrShortCipherText = errors.New("ciphertext too short") var ErrNoEncryptionKey = errors.New("encryption key is required") var ErrInvalidKeySize = errors.New("invalid key size: must be 16, 24, or 32 bytes") +var ErrZeroKey = errors.New("encryption key cannot be all zeros") -// NewKey generates a random AES key of the specified size. -// If no size is provided, it defaults to KeySize256 (32 bytes). +// GenerateKey generates a random AES key of default size KeySize256 (32 bytes). // Returns an error if the specified size is invalid or if key generation fails. -func NewKey(size ...KeySize) ([]byte, error) { - keySize := KeySize256 - if len(size) > 0 { - keySize = size[0] - if !keySize.IsValid() { - return nil, ErrInvalidKeySize - } +func GenerateKey() ([]byte, error) { + return GenerateKeyWithSize(KeySize256) +} + +// GenerateKeyWithSize generates a random AES key of the specified size. +// Returns an error if the specified size is invalid or if key generation fails. +func GenerateKeyWithSize(size KeySize) ([]byte, error) { + if !size.IsValid() { + return nil, ErrInvalidKeySize } - key := make([]byte, keySize) + key := make([]byte, size) if _, err := io.ReadFull(rand.Reader, key); err != nil { return nil, fmt.Errorf("failed to generate AES key: %w", err) } @@ -119,5 +121,12 @@ func validateAESKey(key []byte) error { return ErrInvalidKeySize } - return nil + // check if key is all zeros + for _, b := range key { + if b != 0 { + return nil + } + } + + return ErrZeroKey } diff --git a/pkg/meta/internal/crypto/aes_test.go b/pkg/meta/internal/crypto/aes_test.go index 3462a10..1d0d3e4 100644 --- a/pkg/meta/internal/crypto/aes_test.go +++ b/pkg/meta/internal/crypto/aes_test.go @@ -3,7 +3,6 @@ package crypto import ( "bytes" "crypto/rand" - "fmt" "testing" "github.com/stretchr/testify/require" @@ -13,38 +12,42 @@ func TestAESEncryption(t *testing.T) { t.Parallel() key := make([]byte, 32) // generated random 32-byte key - _, err := rand.Read(key) - require.NoError(t, err) + _, errKey := rand.Read(key) + require.NoError(t, errKey) tests := []struct { name string data []byte key []byte - wantErr bool + wantErr error }{ { - name: "valid encryption/decryption", - data: []byte("hello world"), - key: key, - wantErr: false, + name: "valid encryption/decryption", + data: []byte("hello world"), + key: key, }, { name: "nil key returns error", data: []byte("hello world"), key: nil, - wantErr: true, + wantErr: ErrNoEncryptionKey, }, { - name: "empty data", - data: []byte{}, - key: key, - wantErr: false, + name: "empty data", + data: []byte{}, + key: key, }, { name: "invalid key size", data: []byte("hello world"), key: make([]byte, 31), - wantErr: true, + wantErr: ErrInvalidKeySize, + }, + { + name: "zero key returns error", + data: []byte("hello world"), + key: make([]byte, 32), + wantErr: ErrZeroKey, }, } @@ -54,14 +57,13 @@ func TestAESEncryption(t *testing.T) { t.Parallel() encrypted, err := EncryptWithAESKey(tt.data, tt.key) - if tt.wantErr { - require.Error(t, err) + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + return } require.NoError(t, err) - fmt.Println(string(encrypted)) - decrypted, err := DecryptStringWithAESKey(encrypted, tt.key) require.NoError(t, err)