validate non-zero aes key and other refactoring

This commit is contained in:
Fabio Bozzo
2024-11-12 16:04:33 +01:00
parent 9f47418bdf
commit a26d836025
2 changed files with 40 additions and 29 deletions

View File

@@ -31,20 +31,22 @@ func (ks KeySize) IsValid() bool {
var ErrShortCipherText = errors.New("ciphertext too short")
var ErrNoEncryptionKey = errors.New("encryption key is required")
var ErrInvalidKeySize = errors.New("invalid key size: must be 16, 24, or 32 bytes")
var ErrZeroKey = errors.New("encryption key cannot be all zeros")
// NewKey generates a random AES key of the specified size.
// If no size is provided, it defaults to KeySize256 (32 bytes).
// GenerateKey generates a random AES key of default size KeySize256 (32 bytes).
// Returns an error if the specified size is invalid or if key generation fails.
func NewKey(size ...KeySize) ([]byte, error) {
keySize := KeySize256
if len(size) > 0 {
keySize = size[0]
if !keySize.IsValid() {
return nil, ErrInvalidKeySize
}
func GenerateKey() ([]byte, error) {
return GenerateKeyWithSize(KeySize256)
}
// GenerateKeyWithSize generates a random AES key of the specified size.
// Returns an error if the specified size is invalid or if key generation fails.
func GenerateKeyWithSize(size KeySize) ([]byte, error) {
if !size.IsValid() {
return nil, ErrInvalidKeySize
}
key := make([]byte, keySize)
key := make([]byte, size)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate AES key: %w", err)
}
@@ -119,5 +121,12 @@ func validateAESKey(key []byte) error {
return ErrInvalidKeySize
}
return nil
// check if key is all zeros
for _, b := range key {
if b != 0 {
return nil
}
}
return ErrZeroKey
}

View File

@@ -3,7 +3,6 @@ package crypto
import (
"bytes"
"crypto/rand"
"fmt"
"testing"
"github.com/stretchr/testify/require"
@@ -13,38 +12,42 @@ func TestAESEncryption(t *testing.T) {
t.Parallel()
key := make([]byte, 32) // generated random 32-byte key
_, err := rand.Read(key)
require.NoError(t, err)
_, errKey := rand.Read(key)
require.NoError(t, errKey)
tests := []struct {
name string
data []byte
key []byte
wantErr bool
wantErr error
}{
{
name: "valid encryption/decryption",
data: []byte("hello world"),
key: key,
wantErr: false,
name: "valid encryption/decryption",
data: []byte("hello world"),
key: key,
},
{
name: "nil key returns error",
data: []byte("hello world"),
key: nil,
wantErr: true,
wantErr: ErrNoEncryptionKey,
},
{
name: "empty data",
data: []byte{},
key: key,
wantErr: false,
name: "empty data",
data: []byte{},
key: key,
},
{
name: "invalid key size",
data: []byte("hello world"),
key: make([]byte, 31),
wantErr: true,
wantErr: ErrInvalidKeySize,
},
{
name: "zero key returns error",
data: []byte("hello world"),
key: make([]byte, 32),
wantErr: ErrZeroKey,
},
}
@@ -54,14 +57,13 @@ func TestAESEncryption(t *testing.T) {
t.Parallel()
encrypted, err := EncryptWithAESKey(tt.data, tt.key)
if tt.wantErr {
require.Error(t, err)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
fmt.Println(string(encrypted))
decrypted, err := DecryptStringWithAESKey(encrypted, tt.key)
require.NoError(t, err)