refactor meta/internal/crypto and add key generation method

This commit is contained in:
Fabio Bozzo
2024-11-12 15:29:48 +01:00
parent 7cb0f97b30
commit 3987e8649c
6 changed files with 142 additions and 64 deletions

View File

@@ -0,0 +1,123 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
)
// KeySize represents valid AES key sizes
type KeySize int
const (
KeySize128 KeySize = 16 // AES-128
KeySize192 KeySize = 24 // AES-192
KeySize256 KeySize = 32 // AES-256 (recommended)
)
// IsValid returns true if the key size is valid for AES
func (ks KeySize) IsValid() bool {
switch ks {
case KeySize128, KeySize192, KeySize256:
return true
default:
return false
}
}
var ErrShortCipherText = errors.New("ciphertext too short")
var ErrNoEncryptionKey = errors.New("encryption key is required")
var ErrInvalidKeySize = errors.New("invalid key size: must be 16, 24, or 32 bytes")
// NewKey generates a random AES key of the specified size.
// If no size is provided, it defaults to KeySize256 (32 bytes).
// Returns an error if the specified size is invalid or if key generation fails.
func NewKey(size ...KeySize) ([]byte, error) {
keySize := KeySize256
if len(size) > 0 {
keySize = size[0]
if !keySize.IsValid() {
return nil, ErrInvalidKeySize
}
}
key := make([]byte, keySize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate AES key: %w", err)
}
return key, nil
}
// EncryptWithAESKey encrypts data using AES-GCM with the provided key.
// The key must be 16, 24, or 32 bytes long (for AES-128, AES-192, or AES-256).
// Returns the encrypted data with the nonce prepended, or an error if encryption fails.
func EncryptWithAESKey(data, key []byte) ([]byte, error) {
if err := validateAESKey(key); err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, data, nil), nil
}
// DecryptStringWithAESKey decrypts data that was encrypted with EncryptWithAESKey.
// The key must match the one used for encryption.
// Expects the input to have a prepended nonce.
// Returns the decrypted data or an error if decryption fails.
func DecryptStringWithAESKey(data, key []byte) ([]byte, error) {
if err := validateAESKey(key); err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(data) < gcm.NonceSize() {
return nil, ErrShortCipherText
}
nonce, ciphertext := data[:gcm.NonceSize()], data[gcm.NonceSize():]
decrypted, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return decrypted, nil
}
func validateAESKey(key []byte) error {
if key == nil {
return ErrNoEncryptionKey
}
if !KeySize(len(key)).IsValid() {
return ErrInvalidKeySize
}
return nil
}

View File

@@ -0,0 +1,122 @@
package crypto
import (
"bytes"
"crypto/rand"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestAESEncryption(t *testing.T) {
t.Parallel()
key := make([]byte, 32) // generated random 32-byte key
_, err := rand.Read(key)
require.NoError(t, err)
tests := []struct {
name string
data []byte
key []byte
wantErr bool
}{
{
name: "valid encryption/decryption",
data: []byte("hello world"),
key: key,
wantErr: false,
},
{
name: "nil key returns error",
data: []byte("hello world"),
key: nil,
wantErr: true,
},
{
name: "empty data",
data: []byte{},
key: key,
wantErr: false,
},
{
name: "invalid key size",
data: []byte("hello world"),
key: make([]byte, 31),
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
encrypted, err := EncryptWithAESKey(tt.data, tt.key)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
fmt.Println(string(encrypted))
decrypted, err := DecryptStringWithAESKey(encrypted, tt.key)
require.NoError(t, err)
if tt.key == nil {
require.Equal(t, tt.data, encrypted)
require.Equal(t, tt.data, decrypted)
} else {
require.NotEqual(t, tt.data, encrypted)
require.True(t, bytes.Equal(tt.data, decrypted))
}
})
}
}
func TestDecryptionErrors(t *testing.T) {
t.Parallel()
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
tests := []struct {
name string
data []byte
key []byte
errMsg string
}{
{
name: "short ciphertext",
data: []byte("short"),
key: key,
errMsg: "ciphertext too short",
},
{
name: "invalid ciphertext",
data: make([]byte, 16), // just nonce size
key: key,
errMsg: "message authentication failed",
},
{
name: "missing key",
data: []byte("<22>`M<><4D><EFBFBD>l\u001AIF<49>\u0012<31><32><EFBFBD>=h<>?<3F>c<EFBFBD> <20><>\u0012<31><32><EFBFBD><EFBFBD>\u001C<31>\u0018Ƽ(g"),
key: nil,
errMsg: "encryption key is required",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := DecryptStringWithAESKey(tt.data, tt.key)
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
})
}
}

View File

@@ -1,6 +1,7 @@
package meta
import (
"errors"
"fmt"
"reflect"
"strings"
@@ -10,7 +11,7 @@ import (
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/ipld/go-ipld-prime/printer"
"github.com/ucan-wg/go-ucan/pkg/crypto"
"github.com/ucan-wg/go-ucan/pkg/meta/internal/crypto"
)
var ErrUnsupported = errors.New("failure adding unsupported type to meta")

View File

@@ -17,6 +17,10 @@ func (r ReadOnly) GetString(key string) (string, error) {
return r.m.GetString(key)
}
func (r ReadOnly) GetEncryptedString(key string, encryptionKey []byte) (string, error) {
return r.m.GetEncryptedString(key, encryptionKey)
}
func (r ReadOnly) GetInt64(key string) (int64, error) {
return r.m.GetInt64(key)
}
@@ -29,6 +33,10 @@ func (r ReadOnly) GetBytes(key string) ([]byte, error) {
return r.m.GetBytes(key)
}
func (r ReadOnly) GetEncryptedBytes(key string, encryptionKey []byte) ([]byte, error) {
return r.m.GetEncryptedBytes(key, encryptionKey)
}
func (r ReadOnly) GetNode(key string) (ipld.Node, error) {
return r.m.GetNode(key)
}