validate non-zero aes key and other refactoring
This commit is contained in:
@@ -31,20 +31,22 @@ func (ks KeySize) IsValid() bool {
|
||||
var ErrShortCipherText = errors.New("ciphertext too short")
|
||||
var ErrNoEncryptionKey = errors.New("encryption key is required")
|
||||
var ErrInvalidKeySize = errors.New("invalid key size: must be 16, 24, or 32 bytes")
|
||||
var ErrZeroKey = errors.New("encryption key cannot be all zeros")
|
||||
|
||||
// NewKey generates a random AES key of the specified size.
|
||||
// If no size is provided, it defaults to KeySize256 (32 bytes).
|
||||
// GenerateKey generates a random AES key of default size KeySize256 (32 bytes).
|
||||
// Returns an error if the specified size is invalid or if key generation fails.
|
||||
func NewKey(size ...KeySize) ([]byte, error) {
|
||||
keySize := KeySize256
|
||||
if len(size) > 0 {
|
||||
keySize = size[0]
|
||||
if !keySize.IsValid() {
|
||||
func GenerateKey() ([]byte, error) {
|
||||
return GenerateKeyWithSize(KeySize256)
|
||||
}
|
||||
|
||||
// GenerateKeyWithSize generates a random AES key of the specified size.
|
||||
// Returns an error if the specified size is invalid or if key generation fails.
|
||||
func GenerateKeyWithSize(size KeySize) ([]byte, error) {
|
||||
if !size.IsValid() {
|
||||
return nil, ErrInvalidKeySize
|
||||
}
|
||||
}
|
||||
|
||||
key := make([]byte, keySize)
|
||||
key := make([]byte, size)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate AES key: %w", err)
|
||||
}
|
||||
@@ -119,5 +121,12 @@ func validateAESKey(key []byte) error {
|
||||
return ErrInvalidKeySize
|
||||
}
|
||||
|
||||
// check if key is all zeros
|
||||
for _, b := range key {
|
||||
if b != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrZeroKey
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package crypto
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -13,38 +12,42 @@ func TestAESEncryption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key := make([]byte, 32) // generated random 32-byte key
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
_, errKey := rand.Read(key)
|
||||
require.NoError(t, errKey)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
key []byte
|
||||
wantErr bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid encryption/decryption",
|
||||
data: []byte("hello world"),
|
||||
key: key,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil key returns error",
|
||||
data: []byte("hello world"),
|
||||
key: nil,
|
||||
wantErr: true,
|
||||
wantErr: ErrNoEncryptionKey,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
key: key,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid key size",
|
||||
data: []byte("hello world"),
|
||||
key: make([]byte, 31),
|
||||
wantErr: true,
|
||||
wantErr: ErrInvalidKeySize,
|
||||
},
|
||||
{
|
||||
name: "zero key returns error",
|
||||
data: []byte("hello world"),
|
||||
key: make([]byte, 32),
|
||||
wantErr: ErrZeroKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -54,14 +57,13 @@ func TestAESEncryption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
encrypted, err := EncryptWithAESKey(tt.data, tt.key)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(encrypted))
|
||||
|
||||
decrypted, err := DecryptStringWithAESKey(encrypted, tt.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user