expose secretbox, notably for the GenerateKey() function that should be public

This commit is contained in:
Michael Muré
2024-12-12 16:04:31 +01:00
parent 47156a8ad6
commit 8bb3a4f4d0
3 changed files with 7 additions and 7 deletions

View File

@@ -0,0 +1,90 @@
package secretbox
import (
"crypto/rand"
"errors"
"fmt"
"io"
"golang.org/x/crypto/nacl/secretbox"
)
const keySize = 32 // secretbox allows only 32-byte keys
var ErrShortCipherText = errors.New("ciphertext too short")
var ErrNoEncryptionKey = errors.New("encryption key is required")
var ErrInvalidKeySize = errors.New("invalid key size: must be 32 bytes")
var ErrZeroKey = errors.New("encryption key cannot be all zeros")
// GenerateKey generates a random 32-byte key to be used by EncryptWithKey and DecryptWithKey
func GenerateKey() ([]byte, error) {
key := make([]byte, keySize)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return nil, fmt.Errorf("failed to generate key: %w", err)
}
return key, nil
}
// EncryptWithKey encrypts data using NaCl's secretbox with the provided key.
// 40 bytes of overhead (24-byte nonce + 16-byte MAC) are added to the plaintext size.
func EncryptWithKey(data, key []byte) ([]byte, error) {
if err := validateKey(key); err != nil {
return nil, err
}
var secretKey [keySize]byte
copy(secretKey[:], key)
// Generate 24 bytes of random data as nonce
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
return nil, err
}
// Encrypt and authenticate data
encrypted := secretbox.Seal(nonce[:], data, &nonce, &secretKey)
return encrypted, nil
}
// DecryptStringWithKey decrypts data using secretbox with the provided key
func DecryptStringWithKey(data, key []byte) ([]byte, error) {
if err := validateKey(key); err != nil {
return nil, err
}
if len(data) < 24 {
return nil, ErrShortCipherText
}
var secretKey [keySize]byte
copy(secretKey[:], key)
var nonce [24]byte
copy(nonce[:], data[:24])
decrypted, ok := secretbox.Open(nil, data[24:], &nonce, &secretKey)
if !ok {
return nil, errors.New("decryption failed")
}
return decrypted, nil
}
func validateKey(key []byte) error {
if key == nil {
return ErrNoEncryptionKey
}
if len(key) != keySize {
return ErrInvalidKeySize
}
// check if key is all zeros
for _, b := range key {
if b != 0 {
return nil
}
}
return ErrZeroKey
}

View File

@@ -0,0 +1,144 @@
package secretbox
import (
"bytes"
"crypto/rand"
"testing"
"github.com/stretchr/testify/require"
)
func TestSecretBoxEncryption(t *testing.T) {
t.Parallel()
key := make([]byte, keySize) // generate random 32-byte key
_, errKey := rand.Read(key)
require.NoError(t, errKey)
tests := []struct {
name string
data []byte
key []byte
wantErr error
}{
{
name: "valid encryption/decryption",
data: []byte("hello world"),
key: key,
},
{
name: "nil key returns error",
data: []byte("hello world"),
key: nil,
wantErr: ErrNoEncryptionKey,
},
{
name: "empty data",
data: []byte{},
key: key,
},
{
name: "invalid key size",
data: []byte("hello world"),
key: make([]byte, 16), // Only 32 bytes allowed now
wantErr: ErrInvalidKeySize,
},
{
name: "zero key returns error",
data: []byte("hello world"),
key: make([]byte, keySize),
wantErr: ErrZeroKey,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
encrypted, err := EncryptWithKey(tt.data, tt.key)
if tt.wantErr != nil {
require.ErrorIs(t, err, tt.wantErr)
return
}
require.NoError(t, err)
// Verify encrypted data is different and includes nonce
require.Greater(t, len(encrypted), 24) // At least nonce size
if len(tt.data) > 0 {
require.NotEqual(t, tt.data, encrypted[24:]) // Ignore nonce prefix
}
decrypted, err := DecryptStringWithKey(encrypted, tt.key)
require.NoError(t, err)
require.True(t, bytes.Equal(tt.data, decrypted))
})
}
}
func TestDecryptionErrors(t *testing.T) {
t.Parallel()
key := make([]byte, keySize)
_, err := rand.Read(key)
require.NoError(t, err)
// Create valid encrypted data for tampering tests
validData := []byte("test message")
encrypted, err := EncryptWithKey(validData, key)
require.NoError(t, err)
tests := []struct {
name string
data []byte
key []byte
errMsg string
}{
{
name: "short ciphertext",
data: make([]byte, 23), // Less than nonce size
key: key,
errMsg: "ciphertext too short",
},
{
name: "invalid ciphertext",
data: make([]byte, 24), // Just nonce size
key: key,
errMsg: "decryption failed",
},
{
name: "tampered ciphertext",
data: tamperWithBytes(encrypted),
key: key,
errMsg: "decryption failed",
},
{
name: "missing key",
data: encrypted,
key: nil,
errMsg: "encryption key is required",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := DecryptStringWithKey(tt.data, tt.key)
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
})
}
}
// tamperWithBytes modifies a byte in the encrypted data to simulate tampering
func tamperWithBytes(data []byte) []byte {
if len(data) < 25 { // Need at least nonce + 1 byte
return data
}
tampered := make([]byte, len(data))
copy(tampered, data)
tampered[24] ^= 0x01 // Modify first byte after nonce
return tampered
}