Files
crypto/salt/salt_test.go

422 lines
9.0 KiB
Go

package salt
import (
"bytes"
"fmt"
"testing"
)
// TestGenerate checks that Generate accepts every size in the supported range,
// rejects sizes outside it, and on success returns a non-nil, non-empty salt
// of exactly the requested length whose bytes are not all zero.
func TestGenerate(t *testing.T) {
	type testCase struct {
		name    string
		size    int
		wantErr bool
	}
	cases := []testCase{
		{name: "minimum size", size: MinSaltSize},
		{name: "default size", size: DefaultSaltSize},
		{name: "large size", size: 512},
		{name: "maximum size", size: MaxSaltSize},
		{name: "too small", size: MinSaltSize - 1, wantErr: true},
		{name: "too large", size: MaxSaltSize + 1, wantErr: true},
		{name: "zero size", size: 0, wantErr: true},
		{name: "negative size", size: -1, wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			s, err := Generate(tc.size)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("Generate() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr {
				return
			}
			if s == nil {
				t.Error("Generate() returned nil salt without error")
				return
			}
			if got := s.Size(); got != tc.size {
				t.Errorf("Generate() size = %d, want %d", got, tc.size)
			}
			if s.IsEmpty() {
				t.Error("Generate() returned empty salt")
			}
			// Crude randomness sanity check: random data of this length is
			// vanishingly unlikely to consist solely of zero bytes.
			data := s.Bytes()
			i := 0
			for i < len(data) && data[i] == 0 {
				i++
			}
			if i == len(data) {
				t.Error("Generate() returned all-zero salt (likely not random)")
			}
		})
	}
}
// TestGenerateDefault verifies that GenerateDefault succeeds and produces a
// salt of DefaultSaltSize bytes.
func TestGenerateDefault(t *testing.T) {
	s, err := GenerateDefault()
	if err != nil {
		t.Fatalf("GenerateDefault() error = %v", err)
	}
	if got, want := s.Size(), DefaultSaltSize; got != want {
		t.Errorf("GenerateDefault() size = %d, want %d", got, want)
	}
}
// TestFromBytes checks FromBytes' length validation and that an accepted salt
// holds its own copy of the input: equal content on creation, and unaffected
// by later writes to the caller's slice.
func TestFromBytes(t *testing.T) {
	sequential := make([]byte, DefaultSaltSize)
	for i := range sequential {
		sequential[i] = byte(i)
	}
	cases := []struct {
		name    string
		input   []byte
		wantErr bool
	}{
		{"valid default size", sequential, false},
		{"minimum size", make([]byte, MinSaltSize), false},
		{"large size", make([]byte, 512), false},
		{"too small", make([]byte, MinSaltSize-1), true},
		{"too large", make([]byte, MaxSaltSize+1), true},
		{"empty", []byte{}, true},
		{"nil", nil, true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			s, err := FromBytes(tc.input)
			switch {
			case (err != nil) != tc.wantErr:
				t.Errorf("FromBytes() error = %v, wantErr %v", err, tc.wantErr)
				return
			case tc.wantErr:
				return
			case s == nil:
				t.Error("FromBytes() returned nil salt without error")
				return
			}
			if got, want := s.Size(), len(tc.input); got != want {
				t.Errorf("FromBytes() size = %d, want %d", got, want)
			}
			if !bytes.Equal(s.Bytes(), tc.input) {
				t.Error("FromBytes() data doesn't match input")
			}
			if len(tc.input) == 0 {
				return
			}
			// Flip every bit of the caller's first byte; a salt that aliases
			// the input would observe the change.
			before := tc.input[0]
			tc.input[0] ^= 0xFF
			if s.Bytes()[0] != before {
				t.Error("FromBytes() salt was affected by external modification")
			}
		})
	}
}
// TestSaltBytes verifies that Bytes returns identical content on every call,
// hands out a distinct copy each time (no shared backing array), and that
// writing to a returned slice does not alter the salt itself.
func TestSaltBytes(t *testing.T) {
	salt, err := GenerateDefault()
	if err != nil {
		t.Fatalf("Failed to generate salt: %v", err)
	}
	bytes1 := salt.Bytes()
	bytes2 := salt.Bytes()
	// Should return same data
	if !bytes.Equal(bytes1, bytes2) {
		t.Error("Bytes() returned different data on multiple calls")
	}
	// The aliasing and mutation checks below index element 0; fail with a
	// clear message rather than panicking if either result is empty.
	if len(bytes1) == 0 || len(bytes2) == 0 {
		t.Fatalf("Bytes() returned empty data (len %d and %d)", len(bytes1), len(bytes2))
	}
	// Should be independent copies
	if &bytes1[0] == &bytes2[0] {
		t.Error("Bytes() returned same underlying array (not a copy)")
	}
	// Modifying returned bytes shouldn't affect salt
	originalValue := bytes1[0]
	bytes1[0] = ^bytes1[0]
	if bytes3 := salt.Bytes(); len(bytes3) == 0 || bytes3[0] != originalValue {
		t.Error("External modification of Bytes() affected salt")
	}
}
// TestSaltEqual exercises Salt.Equal for:
//   - identity
//   - equal content in separate instances
//   - differing content
//   - nil on either or both sides
func TestSaltEqual(t *testing.T) {
	first, err := Generate(DefaultSaltSize)
	if err != nil {
		t.Fatalf("Failed to generate salt1: %v", err)
	}
	second, err := Generate(DefaultSaltSize)
	if err != nil {
		t.Fatalf("Failed to generate salt2: %v", err)
	}
	// Rebuilt from first's bytes: same content, separate instance.
	clone, err := FromBytes(first.Bytes())
	if err != nil {
		t.Fatalf("Failed to create salt3: %v", err)
	}
	cases := []struct {
		name        string
		left, right *Salt
		want        bool
	}{
		{"same salt", first, first, true},
		{"equivalent salts", first, clone, true},
		{"different salts", first, second, false},
		{"nil salts", nil, nil, true},
		{"one nil salt", first, nil, false},
		{"nil vs non-nil", nil, first, false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := tc.left.Equal(tc.right); got != tc.want {
				t.Errorf("Equal() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestSaltClear verifies that Clear leaves a previously non-empty salt with
// zero size, reports it as empty, and makes Bytes return nil.
func TestSaltClear(t *testing.T) {
	salt, err := GenerateDefault()
	if err != nil {
		t.Fatalf("Failed to generate salt: %v", err)
	}
	if salt.Size() == 0 {
		t.Fatal("Salt size is zero before clear")
	}
	salt.Clear()
	if salt.Size() != 0 {
		t.Error("Salt size is not zero after clear")
	}
	if !salt.IsEmpty() {
		t.Error("Salt is not empty after clear")
	}
	// The result is named data, not bytes, so it does not shadow the imported
	// bytes package.
	if data := salt.Bytes(); data != nil {
		t.Error("Bytes() should return nil after clear")
	}
}
// TestSaltStore walks a SaltStore through its lifecycle (Store, Retrieve,
// GenerateAndStore, List, Remove, Clear) and checks Size after each mutation.
// It also checks that every stored salt can be read back with matching content.
func TestSaltStore(t *testing.T) {
	store := NewSaltStore()
	if store.Size() != 0 {
		t.Error("New store should be empty")
	}

	// Store an externally generated salt and read it back.
	salt1, err := GenerateDefault()
	if err != nil {
		t.Fatalf("Failed to generate salt: %v", err)
	}
	if err = store.Store("test1", salt1); err != nil {
		t.Fatalf("Failed to store salt: %v", err)
	}
	if store.Size() != 1 {
		t.Errorf("Store size = %d, want 1", store.Size())
	}
	retrieved, err := store.Retrieve("test1")
	if err != nil {
		t.Fatalf("Failed to retrieve salt: %v", err)
	}
	if !salt1.Equal(retrieved) {
		t.Error("Retrieved salt doesn't match stored salt")
	}

	// GenerateAndStore must both return the new salt and make it retrievable.
	salt2, err := store.GenerateAndStore("test2", DefaultSaltSize)
	if err != nil {
		t.Fatalf("Failed to generate and store salt: %v", err)
	}
	if store.Size() != 2 {
		t.Errorf("Store size = %d, want 2", store.Size())
	}
	if salt2.Size() != DefaultSaltSize {
		t.Errorf("Generated salt size = %d, want %d", salt2.Size(), DefaultSaltSize)
	}
	retrieved2, err := store.Retrieve("test2")
	if err != nil {
		t.Fatalf("Failed to retrieve generated salt: %v", err)
	}
	if !salt2.Equal(retrieved2) {
		t.Error("Retrieved salt doesn't match generated salt")
	}

	// List must report both identifiers stored so far.
	ids := store.List()
	if len(ids) != 2 {
		t.Errorf("List() returned %d ids, want 2", len(ids))
	}
	seen := make(map[string]bool, len(ids))
	for _, id := range ids {
		seen[id] = true
	}
	if !seen["test1"] || !seen["test2"] {
		t.Error("List() doesn't contain expected IDs")
	}

	// Remove deletes only the named entry.
	if err = store.Remove("test1"); err != nil {
		t.Fatalf("Failed to remove salt: %v", err)
	}
	if store.Size() != 1 {
		t.Errorf("Store size = %d, want 1 after removal", store.Size())
	}
	if _, err = store.Retrieve("test1"); err == nil {
		t.Error("Should not be able to retrieve removed salt")
	}
	if _, err = store.Retrieve("test2"); err != nil {
		t.Errorf("Removing test1 made test2 unretrievable: %v", err)
	}

	// Clear empties the store entirely.
	store.Clear()
	if store.Size() != 0 {
		t.Error("Store should be empty after clear")
	}
	if _, err = store.Retrieve("test2"); err == nil {
		t.Error("Should not be able to retrieve salt after clear")
	}
}
// TestSaltStoreErrors checks that SaltStore rejects each invalid input with an
// error. The invalid inputs are empty identifiers, nil or empty salts, and
// unknown identifiers. It also checks that no rejected Store call leaves an
// entry behind.
func TestSaltStoreErrors(t *testing.T) {
	store := NewSaltStore()

	// Use a valid salt for the empty-identifier case. Passing nil here (as a
	// bare Store("", nil) would) lets the nil-salt check produce the error, so
	// the identifier validation would go untested.
	valid, err := GenerateDefault()
	if err != nil {
		t.Fatalf("Failed to generate salt: %v", err)
	}

	// Test empty identifier errors
	if err = store.Store("", valid); err == nil {
		t.Error("Should error on empty identifier")
	}
	if _, err = store.Retrieve(""); err == nil {
		t.Error("Should error on empty identifier")
	}
	if err = store.Remove(""); err == nil {
		t.Error("Should error on empty identifier")
	}
	// Test nil salt error
	if err = store.Store("test", nil); err == nil {
		t.Error("Should error on nil salt")
	}
	// Test empty salt error
	emptySalt := &Salt{}
	if err = store.Store("test", emptySalt); err == nil {
		t.Error("Should error on empty salt")
	}
	// Test retrieve non-existent
	if _, err = store.Retrieve("nonexistent"); err == nil {
		t.Error("Should error when retrieving non-existent salt")
	}
	// Test remove non-existent
	if err = store.Remove("nonexistent"); err == nil {
		t.Error("Should error when removing non-existent salt")
	}
	// Every Store call above was invalid, so none may have inserted an entry.
	if n := store.Size(); n != 0 {
		t.Errorf("Store size = %d after rejected stores, want 0", n)
	}
}
// TestConstantTimeCompare checks the package-internal constantTimeCompare
// helper for:
//   - matching content
//   - a single differing byte
//   - differing lengths
//   - empty inputs
func TestConstantTimeCompare(t *testing.T) {
	base := []byte{1, 2, 3, 4, 5}
	for _, tc := range []struct {
		name string
		x, y []byte
		want bool
	}{
		{"equal slices", base, []byte{1, 2, 3, 4, 5}, true},
		{"different content", base, []byte{1, 2, 3, 4, 6}, false},
		{"different length", base, []byte{1, 2, 3, 4}, false},
		{"empty slices", []byte{}, []byte{}, true},
		{"one empty", base, []byte{}, false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if got := constantTimeCompare(tc.x, tc.y); got != tc.want {
				t.Errorf("constantTimeCompare() = %v, want %v", got, tc.want)
			}
		})
	}
}
// BenchmarkGenerate measures Generate at three sizes: the minimum, the
// default, and 512 bytes. Each iteration clears the salt it just created.
func BenchmarkGenerate(b *testing.B) {
	for _, n := range []int{MinSaltSize, DefaultSaltSize, 512} {
		b.Run(fmt.Sprintf("%dB", n), func(b *testing.B) {
			for iter := b.N; iter > 0; iter-- {
				s, err := Generate(n)
				if err != nil {
					b.Fatalf("Generate error: %v", err)
				}
				s.Clear()
			}
		})
	}
}
// BenchmarkSaltEqual measures Salt.Equal in two cases: two salts with
// identical content ("equal"), and two independently generated salts
// ("different").
func BenchmarkSaltEqual(b *testing.B) {
	// Setup failures must abort the benchmark. Ignoring them would silently
	// measure comparisons between nil salts instead.
	salt1, err := GenerateDefault()
	if err != nil {
		b.Fatalf("GenerateDefault error: %v", err)
	}
	salt2, err := GenerateDefault()
	if err != nil {
		b.Fatalf("GenerateDefault error: %v", err)
	}
	salt3, err := FromBytes(salt1.Bytes())
	if err != nil {
		b.Fatalf("FromBytes error: %v", err)
	}
	b.Run("equal", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			// Consuming the result keeps the call from being optimized away
			// and confirms the benchmark measures the intended case.
			if !salt1.Equal(salt3) {
				b.Fatal("Equal() = false for salts with identical content")
			}
		}
	})
	b.Run("different", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			if salt1.Equal(salt2) {
				b.Fatal("Equal() = true for independently generated salts")
			}
		}
	})
}