diff --git a/pkg/args/args.go b/pkg/args/args.go index 3cee3ce..a840c6e 100644 --- a/pkg/args/args.go +++ b/pkg/args/args.go @@ -16,6 +16,7 @@ import ( "github.com/ipld/go-ipld-prime/node/basicnode" "github.com/ipld/go-ipld-prime/printer" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" "github.com/ucan-wg/go-ucan/pkg/policy/literal" ) @@ -62,6 +63,10 @@ func (a *Args) Add(key string, val any) error { return err } + if err := limits.ValidateIntegerBoundsIPLD(node); err != nil { + return fmt.Errorf("value for key %q: %w", key, err) + } + a.Values[key] = node a.Keys = append(a.Keys, key) @@ -164,3 +169,14 @@ func (a *Args) Clone() *Args { } return res } + +// Validate checks that all values in the Args are valid according to UCAN specs +func (a *Args) Validate() error { + for key, value := range a.Values { + if err := limits.ValidateIntegerBoundsIPLD(value); err != nil { + return fmt.Errorf("value for key %q: %w", key, err) + } + } + + return nil +} diff --git a/pkg/args/args_test.go b/pkg/args/args_test.go index 2a44d0f..938151f 100644 --- a/pkg/args/args_test.go +++ b/pkg/args/args_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ucan-wg/go-ucan/pkg/args" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" "github.com/ucan-wg/go-ucan/pkg/policy/literal" ) @@ -185,6 +186,71 @@ func TestInclude(t *testing.T) { }, maps.Collect(a1.Iter())) } +func TestArgsIntegerBounds(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + key string + val int64 + wantErr string + }{ + { + name: "valid int", + key: "valid", + val: 42, + }, + { + name: "max safe integer", + key: "max", + val: limits.MaxInt53, + }, + { + name: "min safe integer", + key: "min", + val: limits.MinInt53, + }, + { + name: "exceeds max safe integer", + key: "tooBig", + val: limits.MaxInt53 + 1, + wantErr: "exceeds safe integer bounds", + }, + { + name: "below min safe integer", + key: "tooSmall", + val: limits.MinInt53 - 1, + wantErr: "exceeds safe integer bounds", + }, + { + name: "duplicate key", + key: "duplicate", + val: 42, + wantErr: "duplicate key", + }, + } + + a := args.New() + require.NoError(t, a.Add("duplicate", 1)) // tests duplicate key + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := a.Add(tt.key, tt.val) + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + val, err := a.GetNode(tt.key) + require.NoError(t, err) + i, err := val.AsInt() + require.NoError(t, err) + require.Equal(t, tt.val, i) + } + }) + } +} + const ( argsSchema = "type Args { String : Any }" argsName = "Args" diff --git a/pkg/meta/internal/crypto/aes.go b/pkg/meta/internal/crypto/aes.go deleted file mode 100644 index 482402e..0000000 --- a/pkg/meta/internal/crypto/aes.go +++ /dev/null @@ -1,132 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "errors" - "fmt" - "io" -) - -// KeySize represents valid AES key sizes -type KeySize int - -const ( - KeySize128 KeySize = 16 // AES-128 - KeySize192 KeySize = 24 // AES-192 - KeySize256 KeySize = 32 // AES-256 (recommended) -) - -// IsValid returns true if the key size is valid for AES -func (ks KeySize) IsValid() bool { - switch ks { - case KeySize128, KeySize192, KeySize256: - return true - default: - return false - } -} - -var ErrShortCipherText = errors.New("ciphertext too short") -var ErrNoEncryptionKey = errors.New("encryption key is required") -var ErrInvalidKeySize = errors.New("invalid key size: must be 16, 24, or 32 bytes") -var ErrZeroKey = errors.New("encryption key cannot be all zeros") - -// GenerateKey generates a random AES key of default size KeySize256 (32 bytes). -// Returns an error if the specified size is invalid or if key generation fails. -func GenerateKey() ([]byte, error) { - return GenerateKeyWithSize(KeySize256) -} - -// GenerateKeyWithSize generates a random AES key of the specified size. -// Returns an error if the specified size is invalid or if key generation fails. -func GenerateKeyWithSize(size KeySize) ([]byte, error) { - if !size.IsValid() { - return nil, ErrInvalidKeySize - } - - key := make([]byte, size) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - return nil, fmt.Errorf("failed to generate AES key: %w", err) - } - - return key, nil -} - -// EncryptWithAESKey encrypts data using AES-GCM with the provided key. -// The key must be 16, 24, or 32 bytes long (for AES-128, AES-192, or AES-256). -// Returns the encrypted data with the nonce prepended, or an error if encryption fails. -func EncryptWithAESKey(data, key []byte) ([]byte, error) { - if err := validateAESKey(key); err != nil { - return nil, err - } - - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - return nil, err - } - - return gcm.Seal(nonce, nonce, data, nil), nil -} - -// DecryptStringWithAESKey decrypts data that was encrypted with EncryptWithAESKey. -// The key must match the one used for encryption. -// Expects the input to have a prepended nonce. -// Returns the decrypted data or an error if decryption fails. -func DecryptStringWithAESKey(data, key []byte) ([]byte, error) { - if err := validateAESKey(key); err != nil { - return nil, err - } - - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - if len(data) < gcm.NonceSize() { - return nil, ErrShortCipherText - } - - nonce, ciphertext := data[:gcm.NonceSize()], data[gcm.NonceSize():] - decrypted, err := gcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return nil, err - } - - return decrypted, nil -} - -func validateAESKey(key []byte) error { - if key == nil { - return ErrNoEncryptionKey - } - - if !KeySize(len(key)).IsValid() { - return ErrInvalidKeySize - } - - // check if key is all zeros - for _, b := range key { - if b != 0 { - return nil - } - } - - return ErrZeroKey -} diff --git a/pkg/meta/internal/crypto/secretbox.go b/pkg/meta/internal/crypto/secretbox.go new file mode 100644 index 0000000..690be7e --- /dev/null +++ b/pkg/meta/internal/crypto/secretbox.go @@ -0,0 +1,90 @@ +package crypto + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/nacl/secretbox" +) + +const keySize = 32 // secretbox allows only 32-byte keys + +var ErrShortCipherText = errors.New("ciphertext too short") +var ErrNoEncryptionKey = errors.New("encryption key is required") +var ErrInvalidKeySize = errors.New("invalid key size: must be 32 bytes") +var ErrZeroKey = errors.New("encryption key cannot be all zeros") + +// GenerateKey generates a random 32-byte key to be used by EncryptWithKey and DecryptWithKey +func GenerateKey() ([]byte, error) { + key := make([]byte, keySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return nil, fmt.Errorf("failed to generate key: %w", err) + } + return key, nil +} + +// EncryptWithKey encrypts data using NaCl's secretbox with the provided key. +// 40 bytes of overhead (24-byte nonce + 16-byte MAC) are added to the plaintext size. +func EncryptWithKey(data, key []byte) ([]byte, error) { + if err := validateKey(key); err != nil { + return nil, err + } + + var secretKey [keySize]byte + copy(secretKey[:], key) + + // Generate 24 bytes of random data as nonce + var nonce [24]byte + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { + return nil, err + } + + // Encrypt and authenticate data + encrypted := secretbox.Seal(nonce[:], data, &nonce, &secretKey) + return encrypted, nil +} + +// DecryptStringWithKey decrypts data using secretbox with the provided key +func DecryptStringWithKey(data, key []byte) ([]byte, error) { + if err := validateKey(key); err != nil { + return nil, err + } + + if len(data) < 24 { + return nil, ErrShortCipherText + } + + var secretKey [keySize]byte + copy(secretKey[:], key) + + var nonce [24]byte + copy(nonce[:], data[:24]) + + decrypted, ok := secretbox.Open(nil, data[24:], &nonce, &secretKey) + if !ok { + return nil, errors.New("decryption failed") + } + + return decrypted, nil +} + +func validateKey(key []byte) error { + if key == nil { + return ErrNoEncryptionKey + } + + if len(key) != keySize { + return ErrInvalidKeySize + } + + // check if key is all zeros + for _, b := range key { + if b != 0 { + return nil + } + } + + return ErrZeroKey +} diff --git a/pkg/meta/internal/crypto/aes_test.go b/pkg/meta/internal/crypto/secretbox_test.go similarity index 53% rename from pkg/meta/internal/crypto/aes_test.go rename to pkg/meta/internal/crypto/secretbox_test.go index 1d0d3e4..d87f860 100644 --- a/pkg/meta/internal/crypto/aes_test.go +++ b/pkg/meta/internal/crypto/secretbox_test.go @@ -8,10 +8,10 @@ import ( "github.com/stretchr/testify/require" ) -func TestAESEncryption(t *testing.T) { +func TestSecretBoxEncryption(t *testing.T) { t.Parallel() - key := make([]byte, 32) // generated random 32-byte key + key := make([]byte, keySize) // generate random 32-byte key _, errKey := rand.Read(key) require.NoError(t, errKey) @@ -40,13 +40,13 @@ func TestAESEncryption(t *testing.T) { { name: "invalid key size", data: []byte("hello world"), - key: make([]byte, 31), + key: make([]byte, 16), // Only 32 bytes allowed now wantErr: ErrInvalidKeySize, }, { name: "zero key returns error", data: []byte("hello world"), - key: make([]byte, 32), + key: make([]byte, keySize), wantErr: ErrZeroKey, }, } @@ -56,24 +56,22 @@ func TestAESEncryption(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - encrypted, err := EncryptWithAESKey(tt.data, tt.key) + encrypted, err := EncryptWithKey(tt.data, tt.key) if tt.wantErr != nil { require.ErrorIs(t, err, tt.wantErr) - return } require.NoError(t, err) - decrypted, err := DecryptStringWithAESKey(encrypted, tt.key) - require.NoError(t, err) - - if tt.key == nil { - require.Equal(t, tt.data, encrypted) - require.Equal(t, tt.data, decrypted) - } else { - require.NotEqual(t, tt.data, encrypted) - require.True(t, bytes.Equal(tt.data, decrypted)) + // Verify encrypted data is different and includes nonce + require.Greater(t, len(encrypted), 24) // At least nonce size + if len(tt.data) > 0 { + require.NotEqual(t, tt.data, encrypted[24:]) // Ignore nonce prefix } + + decrypted, err := DecryptStringWithKey(encrypted, tt.key) + require.NoError(t, err) + require.True(t, bytes.Equal(tt.data, decrypted)) }) } } @@ -81,10 +79,15 @@ func TestAESEncryption(t *testing.T) { func TestDecryptionErrors(t *testing.T) { t.Parallel() - key := make([]byte, 32) + key := make([]byte, keySize) _, err := rand.Read(key) require.NoError(t, err) + // Create valid encrypted data for tampering tests + validData := []byte("test message") + encrypted, err := EncryptWithKey(validData, key) + require.NoError(t, err) + tests := []struct { name string data []byte @@ -93,19 +96,25 @@ func TestDecryptionErrors(t *testing.T) { }{ { name: "short ciphertext", - data: []byte("short"), + data: make([]byte, 23), // Less than nonce size key: key, errMsg: "ciphertext too short", }, { name: "invalid ciphertext", - data: make([]byte, 16), // just nonce size + data: make([]byte, 24), // Just nonce size key: key, - errMsg: "message authentication failed", + errMsg: "decryption failed", + }, + { + name: "tampered ciphertext", + data: tamperWithBytes(encrypted), + key: key, + errMsg: "decryption failed", }, { name: "missing key", - data: []byte("�`M���l\u001AIF�\u0012���=h�?�c� ��\u0012����\u001C�\u0018Ƽ(g"), + data: encrypted, key: nil, errMsg: "encryption key is required", }, @@ -116,9 +125,20 @@ func TestDecryptionErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := DecryptStringWithAESKey(tt.data, tt.key) + _, err := DecryptStringWithKey(tt.data, tt.key) require.Error(t, err) require.Contains(t, err.Error(), tt.errMsg) }) } } + +// tamperWithBytes modifies a byte in the encrypted data to simulate tampering +func tamperWithBytes(data []byte) []byte { + if len(data) < 25 { // Need at least nonce + 1 byte + return data + } + tampered := make([]byte, len(data)) + copy(tampered, data) + tampered[24] ^= 0x01 // Modify first byte after nonce + return tampered +} diff --git a/pkg/meta/meta.go b/pkg/meta/meta.go index 3c54738..9b0e79f 100644 --- a/pkg/meta/meta.go +++ b/pkg/meta/meta.go @@ -63,7 +63,7 @@ func (m *Meta) GetEncryptedString(key string, encryptionKey []byte) (string, err return "", err } - decrypted, err := crypto.DecryptStringWithAESKey(v, encryptionKey) + decrypted, err := crypto.DecryptStringWithKey(v, encryptionKey) if err != nil { return "", err } @@ -111,7 +111,7 @@ func (m *Meta) GetEncryptedBytes(key string, encryptionKey []byte) ([]byte, erro return nil, err } - decrypted, err := crypto.DecryptStringWithAESKey(v, encryptionKey) + decrypted, err := crypto.DecryptStringWithKey(v, encryptionKey) if err != nil { return nil, err } @@ -150,18 +150,19 @@ func (m *Meta) Add(key string, val any) error { // AddEncrypted adds a key/value pair in the meta set. // The value is encrypted with the given encryptionKey. // Accepted types for the value are: string, []byte. +// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead. func (m *Meta) AddEncrypted(key string, val any, encryptionKey []byte) error { var encrypted []byte var err error switch val := val.(type) { case string: - encrypted, err = crypto.EncryptWithAESKey([]byte(val), encryptionKey) + encrypted, err = crypto.EncryptWithKey([]byte(val), encryptionKey) if err != nil { return err } case []byte: - encrypted, err = crypto.EncryptWithAESKey(val, encryptionKey) + encrypted, err = crypto.EncryptWithKey(val, encryptionKey) if err != nil { return err } diff --git a/pkg/policy/ipld.go b/pkg/policy/ipld.go index 9d52d4d..752dd96 100644 --- a/pkg/policy/ipld.go +++ b/pkg/policy/ipld.go @@ -9,10 +9,15 @@ import ( "github.com/ipld/go-ipld-prime/must" "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" "github.com/ucan-wg/go-ucan/pkg/policy/selector" ) func FromIPLD(node datamodel.Node) (Policy, error) { + if err := limits.ValidateIntegerBoundsIPLD(node); err != nil { + return nil, fmt.Errorf("policy contains integer values outside safe bounds: %w", err) + } + return statementsFromIPLD("/", node) } diff --git a/pkg/policy/limits/int.go b/pkg/policy/limits/int.go new file mode 100644 index 0000000..91dd135 --- /dev/null +++ b/pkg/policy/limits/int.go @@ -0,0 +1,49 @@ +package limits + +import ( + "fmt" + + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/must" +) + +const ( + // MaxInt53 represents the maximum safe integer in JavaScript (2^53 - 1) + MaxInt53 = 9007199254740991 + // MinInt53 represents the minimum safe integer in JavaScript (-2^53 + 1) + MinInt53 = -9007199254740991 +) + +func ValidateIntegerBoundsIPLD(node ipld.Node) error { + switch node.Kind() { + case ipld.Kind_Int: + val := must.Int(node) + if val > MaxInt53 || val < MinInt53 { + return fmt.Errorf("integer value %d exceeds safe bounds", val) + } + case ipld.Kind_List: + it := node.ListIterator() + for !it.Done() { + _, v, err := it.Next() + if err != nil { + return err + } + if err := ValidateIntegerBoundsIPLD(v); err != nil { + return err + } + } + case ipld.Kind_Map: + it := node.MapIterator() + for !it.Done() { + _, v, err := it.Next() + if err != nil { + return err + } + if err := ValidateIntegerBoundsIPLD(v); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/policy/limits/int_test.go b/pkg/policy/limits/int_test.go new file mode 100644 index 0000000..de3f21e --- /dev/null +++ b/pkg/policy/limits/int_test.go @@ -0,0 +1,82 @@ +package limits + +import ( + "testing" + + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/stretchr/testify/require" +) + +func TestValidateIntegerBoundsIPLD(t *testing.T) { + buildMap := func() datamodel.Node { + nb := basicnode.Prototype.Any.NewBuilder() + qp.Map(1, func(ma datamodel.MapAssembler) { + qp.MapEntry(ma, "foo", qp.Int(MaxInt53+1)) + })(nb) + return nb.Build() + } + + buildList := func() datamodel.Node { + nb := basicnode.Prototype.Any.NewBuilder() + qp.List(1, func(la datamodel.ListAssembler) { + qp.ListEntry(la, qp.Int(MinInt53-1)) + })(nb) + return nb.Build() + } + + tests := []struct { + name string + input datamodel.Node + wantErr bool + }{ + { + name: "valid int", + input: basicnode.NewInt(42), + wantErr: false, + }, + { + name: "max safe int", + input: basicnode.NewInt(MaxInt53), + wantErr: false, + }, + { + name: "min safe int", + input: basicnode.NewInt(MinInt53), + wantErr: false, + }, + { + name: "above MaxInt53", + input: basicnode.NewInt(MaxInt53 + 1), + wantErr: true, + }, + { + name: "below MinInt53", + input: basicnode.NewInt(MinInt53 - 1), + wantErr: true, + }, + { + name: "nested map with invalid int", + input: buildMap(), + wantErr: true, + }, + { + name: "nested list with invalid int", + input: buildList(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateIntegerBoundsIPLD(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "exceeds safe bounds") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index 33b0904..333ade3 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -12,6 +12,8 @@ import ( "github.com/ipld/go-ipld-prime/fluent/qp" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" + + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) var Bool = basicnode.NewBool @@ -58,8 +60,6 @@ func List[T any](l []T) (ipld.Node, error) { // Any creates an IPLD node from any value // If possible, use another dedicated function for your type for performance. func Any(v any) (res ipld.Node, err error) { - // TODO: handle uint overflow below - // some fast path switch val := v.(type) { case bool: @@ -67,7 +67,11 @@ func Any(v any) (res ipld.Node, err error) { case string: return basicnode.NewString(val), nil case int: - return basicnode.NewInt(int64(val)), nil + i := int64(val) + if i > limits.MaxInt53 || i < limits.MinInt53 { + return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", i) + } + return basicnode.NewInt(i), nil case int8: return basicnode.NewInt(int64(val)), nil case int16: @@ -75,6 +79,9 @@ func Any(v any) (res ipld.Node, err error) { case int32: return basicnode.NewInt(int64(val)), nil case int64: + if val > limits.MaxInt53 || val < limits.MinInt53 { + return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", val) + } return basicnode.NewInt(val), nil case uint: return basicnode.NewInt(int64(val)), nil @@ -85,6 +92,9 @@ func Any(v any) (res ipld.Node, err error) { case uint32: return basicnode.NewInt(int64(val)), nil case uint64: + if val > uint64(limits.MaxInt53) { + return nil, fmt.Errorf("unsigned integer value %d exceeds safe integer bounds", val) + } return basicnode.NewInt(int64(val)), nil case float32: return basicnode.NewFloat(float64(val)), nil @@ -168,9 +178,17 @@ func anyAssemble(val any) qp.Assemble { case reflect.Bool: return qp.Bool(rv.Bool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return qp.Int(rv.Int()) + i := rv.Int() + if i > limits.MaxInt53 || i < limits.MinInt53 { + panic(fmt.Sprintf("integer %d exceeds safe bounds", i)) + } + return qp.Int(i) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return qp.Int(int64(rv.Uint())) + u := rv.Uint() + if u > limits.MaxInt53 { + panic(fmt.Sprintf("unsigned integer %d exceeds safe bounds", u)) + } + return qp.Int(int64(u)) case reflect.Float32, reflect.Float64: return qp.Float(rv.Float()) case reflect.String: diff --git a/pkg/policy/literal/literal_test.go b/pkg/policy/literal/literal_test.go index 45d7c6c..9f7e1f2 100644 --- a/pkg/policy/literal/literal_test.go +++ b/pkg/policy/literal/literal_test.go @@ -8,6 +8,8 @@ import ( cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/printer" "github.com/stretchr/testify/require" + + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) func TestList(t *testing.T) { @@ -214,7 +216,7 @@ func TestAny(t *testing.T) { require.NoError(t, err) require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"))) - v, err = Any(data["func"]) + _, err = Any(data["func"]) require.Error(t, err) } @@ -254,6 +256,56 @@ func BenchmarkAny(b *testing.B) { }) } +func TestAnyAssembleIntegerOverflow(t *testing.T) { + tests := []struct { + name string + input interface{} + shouldErr bool + }{ + { + name: "valid int", + input: 42, + shouldErr: false, + }, + { + name: "max safe int", + input: limits.MaxInt53, + shouldErr: false, + }, + { + name: "min safe int", + input: limits.MinInt53, + shouldErr: false, + }, + { + name: "overflow int", + input: int64(limits.MaxInt53 + 1), + shouldErr: true, + }, + { + name: "underflow int", + input: int64(limits.MinInt53 - 1), + shouldErr: true, + }, + { + name: "overflow uint", + input: uint64(limits.MaxInt53 + 1), + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Any(tt.input) + if tt.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + func must[T any](t T, err error) T { if err != nil { panic(err) diff --git a/pkg/policy/match.go b/pkg/policy/match.go index 59316ed..648b877 100644 --- a/pkg/policy/match.go +++ b/pkg/policy/match.go @@ -3,6 +3,7 @@ package policy import ( "cmp" "fmt" + "math" "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/datamodel" @@ -249,10 +250,22 @@ func matchStatement(cur Statement, node ipld.Node) (_ matchResult, leafMost Stat panic(fmt.Errorf("unimplemented statement kind: %s", cur.Kind())) } +// isOrdered compares two IPLD nodes and returns true if they satisfy the given ordering function. +// It supports comparison of integers and floats, returning false for: +// - Nodes of different or unsupported kinds +// - Integer values outside JavaScript's safe integer bounds (±2^53-1) +// - Non-finite floating point values (NaN or ±Inf) +// +// The satisfies parameter is a function that interprets the comparison result: +// - For ">" it returns true when order is 1 +// - For ">=" it returns true when order is 0 or 1 +// - For "<" it returns true when order is -1 +// - For "<=" it returns true when order is -1 or 0 func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool { if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int { a := must.Int(actual) b := must.Int(expected) + return satisfies(cmp.Compare(a, b)) } @@ -265,6 +278,11 @@ func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) b if err != nil { panic(fmt.Errorf("extracting selector float: %w", err)) } + + if math.IsInf(a, 0) || math.IsNaN(a) || math.IsInf(b, 0) || math.IsNaN(b) { + return false + } + return satisfies(cmp.Compare(a, b)) } diff --git a/pkg/policy/selector/parsing.go b/pkg/policy/selector/parsing.go index 2e758af..e6deab8 100644 --- a/pkg/policy/selector/parsing.go +++ b/pkg/policy/selector/parsing.go @@ -6,6 +6,8 @@ import ( "regexp" "strconv" "strings" + + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) var ( @@ -67,6 +69,9 @@ func Parse(str string) (Selector, error) { if err != nil { return nil, newParseError("invalid index", str, col, tok) } + if idx > limits.MaxInt53 || idx < limits.MinInt53 { + return nil, newParseError(fmt.Sprintf("index %d exceeds safe integer bounds", idx), str, col, tok) + } sel = append(sel, segment{str: tok, optional: opt, index: idx}) // explicit field, ["abcd"] @@ -88,6 +93,9 @@ func Parse(str string) (Selector, error) { if err != nil { return nil, newParseError("invalid slice index", str, col, tok) } + if i > limits.MaxInt53 || i < limits.MinInt53 { + return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok) + } rng[0] = i } if splt[1] == "" { @@ -97,6 +105,9 @@ func Parse(str string) (Selector, error) { if err != nil { return nil, newParseError("invalid slice index", str, col, tok) } + if i > limits.MaxInt53 || i < limits.MinInt53 { + return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok) + } rng[1] = i } sel = append(sel, segment{str: tok, optional: opt, slice: rng[:]}) diff --git a/pkg/policy/selector/parsing_test.go b/pkg/policy/selector/parsing_test.go index b7ff22d..e58e9b2 100644 --- a/pkg/policy/selector/parsing_test.go +++ b/pkg/policy/selector/parsing_test.go @@ -1,10 +1,12 @@ package selector import ( + "fmt" "math" "testing" "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) func TestParse(t *testing.T) { @@ -616,4 +618,23 @@ func TestParse(t *testing.T) { require.Nil(t, sel) } }) + + t.Run("integer overflow", func(t *testing.T) { + sel, err := Parse(fmt.Sprintf(".[%d]", limits.MaxInt53+1)) + require.Error(t, err) + require.Nil(t, sel) + + sel, err = Parse(fmt.Sprintf(".[%d]", limits.MinInt53-1)) + require.Error(t, err) + require.Nil(t, sel) + + // Test slice overflow + sel, err = Parse(fmt.Sprintf(".[%d:42]", limits.MaxInt53+1)) + require.Error(t, err) + require.Nil(t, sel) + + sel, err = Parse(fmt.Sprintf(".[1:%d]", limits.MaxInt53+1)) + require.Error(t, err) + require.Nil(t, sel) + }) } diff --git a/pkg/policy/selector/selector.go b/pkg/policy/selector/selector.go index 149078d..249cd44 100644 --- a/pkg/policy/selector/selector.go +++ b/pkg/policy/selector/selector.go @@ -266,19 +266,32 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) { case slice[0] == math.MinInt: start = 0 case slice[0] < 0: - start = length + slice[0] + // Check for potential overflow before adding + if -slice[0] > length { + start = 0 + } else { + start = length + slice[0] + } } + switch { case slice[1] == math.MaxInt: end = length case slice[1] < 0: - end = length + slice[1] + // Check for potential overflow before adding + if -slice[1] > length { + end = 0 + } else { + end = length + slice[1] + } } // backward iteration is not allowed, shortcut to an empty result if start >= end { start, end = 0, 0 + return } + // clamp out of bound if start < 0 { start = 0 @@ -286,11 +299,14 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) { if start > length { start = length } + if end < 0 { + end = 0 + } if end > length { end = length } - return start, end + return } func kindString(n datamodel.Node) string { diff --git a/pkg/policy/selector/selector_test.go b/pkg/policy/selector/selector_test.go index 184b7b3..fdd18ec 100644 --- a/pkg/policy/selector/selector_test.go +++ b/pkg/policy/selector/selector_test.go @@ -2,6 +2,7 @@ package selector import ( "errors" + "math" "strings" "testing" @@ -356,3 +357,57 @@ func FuzzParseAndSelect(f *testing.F) { } }) } + +func TestResolveSliceIndices(t *testing.T) { + tests := []struct { + name string + slice []int64 + length int64 + wantStart int64 + wantEnd int64 + }{ + { + name: "normal case", + slice: []int64{1, 3}, + length: 5, + wantStart: 1, + wantEnd: 3, + }, + { + name: "negative indices", + slice: []int64{-2, -1}, + length: 5, + wantStart: 3, + wantEnd: 4, + }, + { + name: "overflow protection negative start", + slice: []int64{math.MinInt64, 3}, + length: 5, + wantStart: 0, + wantEnd: 3, + }, + { + name: "overflow protection negative end", + slice: []int64{0, math.MinInt64}, + length: 5, + wantStart: 0, + wantEnd: 0, + }, + { + name: "max bounds", + slice: []int64{0, math.MaxInt64}, + length: 5, + wantStart: 0, + wantEnd: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start, end := resolveSliceIndices(tt.slice, tt.length) + require.Equal(t, tt.wantStart, start) + require.Equal(t, tt.wantEnd, end) + }) + } +} diff --git a/token/delegation/delegation.go b/token/delegation/delegation.go index 959a40a..110b8e1 100644 --- a/token/delegation/delegation.go +++ b/token/delegation/delegation.go @@ -215,8 +215,15 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) { tkn.meta = m.Meta - tkn.notBefore = parse.OptionalTimestamp(m.Nbf) - tkn.expiration = parse.OptionalTimestamp(m.Exp) + tkn.notBefore, err = parse.OptionalTimestamp(m.Nbf) + if err != nil { + return nil, fmt.Errorf("parse notBefore: %w", err) + } + + tkn.expiration, err = parse.OptionalTimestamp(m.Exp) + if err != nil { + return nil, fmt.Errorf("parse expiration: %w", err) + } if err := tkn.validate(); err != nil { return nil, err diff --git a/token/delegation/options.go b/token/delegation/options.go index 4df14e7..3348760 100644 --- a/token/delegation/options.go +++ b/token/delegation/options.go @@ -45,7 +45,8 @@ func WithMeta(key string, val any) Option { } // WithEncryptedMetaString adds a key/value pair in the "meta" field. -// The string value is encrypted with the given aesKey. +// The string value is encrypted with the given key. +// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead. func WithEncryptedMetaString(key, val string, encryptionKey []byte) Option { return func(t *Token) error { return t.meta.AddEncrypted(key, val, encryptionKey) @@ -53,7 +54,8 @@ func WithEncryptedMetaString(key, val string, encryptionKey []byte) Option { } // WithEncryptedMetaBytes adds a key/value pair in the "meta" field. -// The []byte value is encrypted with the given aesKey. +// The []byte value is encrypted with the given key. +// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead. func WithEncryptedMetaBytes(key string, val, encryptionKey []byte) Option { return func(t *Token) error { return t.meta.AddEncrypted(key, val, encryptionKey) diff --git a/token/internal/parse/parse.go b/token/internal/parse/parse.go index 147b308..27af240 100644 --- a/token/internal/parse/parse.go +++ b/token/internal/parse/parse.go @@ -1,9 +1,11 @@ package parse import ( + "fmt" "time" "github.com/ucan-wg/go-ucan/did" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) func OptionalDID(s *string) (did.DID, error) { @@ -13,10 +15,15 @@ func OptionalDID(s *string) (did.DID, error) { return did.Parse(*s) } -func OptionalTimestamp(sec *int64) *time.Time { +func OptionalTimestamp(sec *int64) (*time.Time, error) { if sec == nil { - return nil + return nil, nil } + + if *sec > limits.MaxInt53 || *sec < limits.MinInt53 { + return nil, fmt.Errorf("timestamp value %d exceeds safe integer bounds", *sec) + } + t := time.Unix(*sec, 0) - return &t + return &t, nil } diff --git a/token/internal/parse/parse_test.go b/token/internal/parse/parse_test.go new file mode 100644 index 0000000..9db6474 --- /dev/null +++ b/token/internal/parse/parse_test.go @@ -0,0 +1,64 @@ +package parse + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" +) + +func TestOptionalTimestamp(t *testing.T) { + tests := []struct { + name string + input *int64 + wantErr bool + }{ + { + name: "nil timestamp", + input: nil, + wantErr: false, + }, + { + name: "valid timestamp", + input: int64Ptr(1625097600), + wantErr: false, + }, + { + name: "max safe integer", + input: int64Ptr(limits.MaxInt53), + wantErr: false, + }, + { + name: "exceeds max safe integer", + input: int64Ptr(limits.MaxInt53 + 1), + wantErr: true, + }, + { + name: "below min safe integer", + input: int64Ptr(limits.MinInt53 - 1), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := OptionalTimestamp(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "exceeds safe integer bounds") + require.Nil(t, result) + } else { + require.NoError(t, err) + if tt.input == nil { + require.Nil(t, result) + } else { + require.NotNil(t, result) + } + } + }) + } +} + +func int64Ptr(i int64) *int64 { + return &i +} diff --git a/token/invocation/invocation.go b/token/invocation/invocation.go index ee85a94..4ab7b8b 100644 --- a/token/invocation/invocation.go +++ b/token/invocation/invocation.go @@ -272,11 +272,22 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) { tkn.nonce = m.Nonce tkn.arguments = m.Args + if err := tkn.arguments.Validate(); err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + tkn.proof = m.Prf tkn.meta = m.Meta - tkn.expiration = parse.OptionalTimestamp(m.Exp) - tkn.invokedAt = parse.OptionalTimestamp(m.Iat) + tkn.expiration, err = parse.OptionalTimestamp(m.Exp) + if err != nil { + return nil, fmt.Errorf("parse expiration: %w", err) + } + + tkn.invokedAt, err = parse.OptionalTimestamp(m.Iat) + if err != nil { + return nil, fmt.Errorf("parse invokedAt: %w", err) + } tkn.cause = m.Cause