Merge branch 'main' into fix/extended-field-names
# Conflicts: # pkg/policy/selector/parsing_test.go
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/ipld/go-ipld-prime/node/basicnode"
|
||||
"github.com/ipld/go-ipld-prime/printer"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
|
||||
)
|
||||
|
||||
@@ -62,6 +63,10 @@ func (a *Args) Add(key string, val any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := limits.ValidateIntegerBoundsIPLD(node); err != nil {
|
||||
return fmt.Errorf("value for key %q: %w", key, err)
|
||||
}
|
||||
|
||||
a.Values[key] = node
|
||||
a.Keys = append(a.Keys, key)
|
||||
|
||||
@@ -164,3 +169,14 @@ func (a *Args) Clone() *Args {
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Validate checks that all values in the Args are valid according to UCAN specs
|
||||
func (a *Args) Validate() error {
|
||||
for key, value := range a.Values {
|
||||
if err := limits.ValidateIntegerBoundsIPLD(value); err != nil {
|
||||
return fmt.Errorf("value for key %q: %w", key, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/pkg/args"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
|
||||
)
|
||||
|
||||
@@ -185,6 +186,71 @@ func TestInclude(t *testing.T) {
|
||||
}, maps.Collect(a1.Iter()))
|
||||
}
|
||||
|
||||
func TestArgsIntegerBounds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val int64
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid int",
|
||||
key: "valid",
|
||||
val: 42,
|
||||
},
|
||||
{
|
||||
name: "max safe integer",
|
||||
key: "max",
|
||||
val: limits.MaxInt53,
|
||||
},
|
||||
{
|
||||
name: "min safe integer",
|
||||
key: "min",
|
||||
val: limits.MinInt53,
|
||||
},
|
||||
{
|
||||
name: "exceeds max safe integer",
|
||||
key: "tooBig",
|
||||
val: limits.MaxInt53 + 1,
|
||||
wantErr: "exceeds safe integer bounds",
|
||||
},
|
||||
{
|
||||
name: "below min safe integer",
|
||||
key: "tooSmall",
|
||||
val: limits.MinInt53 - 1,
|
||||
wantErr: "exceeds safe integer bounds",
|
||||
},
|
||||
{
|
||||
name: "duplicate key",
|
||||
key: "duplicate",
|
||||
val: 42,
|
||||
wantErr: "duplicate key",
|
||||
},
|
||||
}
|
||||
|
||||
a := args.New()
|
||||
require.NoError(t, a.Add("duplicate", 1)) // tests duplicate key
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := a.Add(tt.key, tt.val)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
val, err := a.GetNode(tt.key)
|
||||
require.NoError(t, err)
|
||||
i, err := val.AsInt()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.val, i)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
argsSchema = "type Args { String : Any }"
|
||||
argsName = "Args"
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// KeySize represents valid AES key sizes
|
||||
type KeySize int
|
||||
|
||||
const (
|
||||
KeySize128 KeySize = 16 // AES-128
|
||||
KeySize192 KeySize = 24 // AES-192
|
||||
KeySize256 KeySize = 32 // AES-256 (recommended)
|
||||
)
|
||||
|
||||
// IsValid returns true if the key size is valid for AES
|
||||
func (ks KeySize) IsValid() bool {
|
||||
switch ks {
|
||||
case KeySize128, KeySize192, KeySize256:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var ErrShortCipherText = errors.New("ciphertext too short")
|
||||
var ErrNoEncryptionKey = errors.New("encryption key is required")
|
||||
var ErrInvalidKeySize = errors.New("invalid key size: must be 16, 24, or 32 bytes")
|
||||
var ErrZeroKey = errors.New("encryption key cannot be all zeros")
|
||||
|
||||
// GenerateKey generates a random AES key of default size KeySize256 (32 bytes).
|
||||
// Returns an error if the specified size is invalid or if key generation fails.
|
||||
func GenerateKey() ([]byte, error) {
|
||||
return GenerateKeyWithSize(KeySize256)
|
||||
}
|
||||
|
||||
// GenerateKeyWithSize generates a random AES key of the specified size.
|
||||
// Returns an error if the specified size is invalid or if key generation fails.
|
||||
func GenerateKeyWithSize(size KeySize) ([]byte, error) {
|
||||
if !size.IsValid() {
|
||||
return nil, ErrInvalidKeySize
|
||||
}
|
||||
|
||||
key := make([]byte, size)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate AES key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// EncryptWithAESKey encrypts data using AES-GCM with the provided key.
|
||||
// The key must be 16, 24, or 32 bytes long (for AES-128, AES-192, or AES-256).
|
||||
// Returns the encrypted data with the nonce prepended, or an error if encryption fails.
|
||||
func EncryptWithAESKey(data, key []byte) ([]byte, error) {
|
||||
if err := validateAESKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
// DecryptStringWithAESKey decrypts data that was encrypted with EncryptWithAESKey.
|
||||
// The key must match the one used for encryption.
|
||||
// Expects the input to have a prepended nonce.
|
||||
// Returns the decrypted data or an error if decryption fails.
|
||||
func DecryptStringWithAESKey(data, key []byte) ([]byte, error) {
|
||||
if err := validateAESKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(data) < gcm.NonceSize() {
|
||||
return nil, ErrShortCipherText
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:gcm.NonceSize()], data[gcm.NonceSize():]
|
||||
decrypted, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func validateAESKey(key []byte) error {
|
||||
if key == nil {
|
||||
return ErrNoEncryptionKey
|
||||
}
|
||||
|
||||
if !KeySize(len(key)).IsValid() {
|
||||
return ErrInvalidKeySize
|
||||
}
|
||||
|
||||
// check if key is all zeros
|
||||
for _, b := range key {
|
||||
if b != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrZeroKey
|
||||
}
|
||||
90
pkg/meta/internal/crypto/secretbox.go
Normal file
90
pkg/meta/internal/crypto/secretbox.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/nacl/secretbox"
|
||||
)
|
||||
|
||||
const keySize = 32 // secretbox allows only 32-byte keys
|
||||
|
||||
var ErrShortCipherText = errors.New("ciphertext too short")
|
||||
var ErrNoEncryptionKey = errors.New("encryption key is required")
|
||||
var ErrInvalidKeySize = errors.New("invalid key size: must be 32 bytes")
|
||||
var ErrZeroKey = errors.New("encryption key cannot be all zeros")
|
||||
|
||||
// GenerateKey generates a random 32-byte key to be used by EncryptWithKey and DecryptWithKey
|
||||
func GenerateKey() ([]byte, error) {
|
||||
key := make([]byte, keySize)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// EncryptWithKey encrypts data using NaCl's secretbox with the provided key.
|
||||
// 40 bytes of overhead (24-byte nonce + 16-byte MAC) are added to the plaintext size.
|
||||
func EncryptWithKey(data, key []byte) ([]byte, error) {
|
||||
if err := validateKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var secretKey [keySize]byte
|
||||
copy(secretKey[:], key)
|
||||
|
||||
// Generate 24 bytes of random data as nonce
|
||||
var nonce [24]byte
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encrypt and authenticate data
|
||||
encrypted := secretbox.Seal(nonce[:], data, &nonce, &secretKey)
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// DecryptStringWithKey decrypts data using secretbox with the provided key
|
||||
func DecryptStringWithKey(data, key []byte) ([]byte, error) {
|
||||
if err := validateKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(data) < 24 {
|
||||
return nil, ErrShortCipherText
|
||||
}
|
||||
|
||||
var secretKey [keySize]byte
|
||||
copy(secretKey[:], key)
|
||||
|
||||
var nonce [24]byte
|
||||
copy(nonce[:], data[:24])
|
||||
|
||||
decrypted, ok := secretbox.Open(nil, data[24:], &nonce, &secretKey)
|
||||
if !ok {
|
||||
return nil, errors.New("decryption failed")
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
func validateKey(key []byte) error {
|
||||
if key == nil {
|
||||
return ErrNoEncryptionKey
|
||||
}
|
||||
|
||||
if len(key) != keySize {
|
||||
return ErrInvalidKeySize
|
||||
}
|
||||
|
||||
// check if key is all zeros
|
||||
for _, b := range key {
|
||||
if b != 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return ErrZeroKey
|
||||
}
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAESEncryption(t *testing.T) {
|
||||
func TestSecretBoxEncryption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key := make([]byte, 32) // generated random 32-byte key
|
||||
key := make([]byte, keySize) // generate random 32-byte key
|
||||
_, errKey := rand.Read(key)
|
||||
require.NoError(t, errKey)
|
||||
|
||||
@@ -40,13 +40,13 @@ func TestAESEncryption(t *testing.T) {
|
||||
{
|
||||
name: "invalid key size",
|
||||
data: []byte("hello world"),
|
||||
key: make([]byte, 31),
|
||||
key: make([]byte, 16), // Only 32 bytes allowed now
|
||||
wantErr: ErrInvalidKeySize,
|
||||
},
|
||||
{
|
||||
name: "zero key returns error",
|
||||
data: []byte("hello world"),
|
||||
key: make([]byte, 32),
|
||||
key: make([]byte, keySize),
|
||||
wantErr: ErrZeroKey,
|
||||
},
|
||||
}
|
||||
@@ -56,24 +56,22 @@ func TestAESEncryption(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
encrypted, err := EncryptWithAESKey(tt.data, tt.key)
|
||||
encrypted, err := EncryptWithKey(tt.data, tt.key)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := DecryptStringWithAESKey(encrypted, tt.key)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.key == nil {
|
||||
require.Equal(t, tt.data, encrypted)
|
||||
require.Equal(t, tt.data, decrypted)
|
||||
} else {
|
||||
require.NotEqual(t, tt.data, encrypted)
|
||||
require.True(t, bytes.Equal(tt.data, decrypted))
|
||||
// Verify encrypted data is different and includes nonce
|
||||
require.Greater(t, len(encrypted), 24) // At least nonce size
|
||||
if len(tt.data) > 0 {
|
||||
require.NotEqual(t, tt.data, encrypted[24:]) // Ignore nonce prefix
|
||||
}
|
||||
|
||||
decrypted, err := DecryptStringWithKey(encrypted, tt.key)
|
||||
require.NoError(t, err)
|
||||
require.True(t, bytes.Equal(tt.data, decrypted))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -81,10 +79,15 @@ func TestAESEncryption(t *testing.T) {
|
||||
func TestDecryptionErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
key := make([]byte, 32)
|
||||
key := make([]byte, keySize)
|
||||
_, err := rand.Read(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create valid encrypted data for tampering tests
|
||||
validData := []byte("test message")
|
||||
encrypted, err := EncryptWithKey(validData, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
@@ -93,19 +96,25 @@ func TestDecryptionErrors(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "short ciphertext",
|
||||
data: []byte("short"),
|
||||
data: make([]byte, 23), // Less than nonce size
|
||||
key: key,
|
||||
errMsg: "ciphertext too short",
|
||||
},
|
||||
{
|
||||
name: "invalid ciphertext",
|
||||
data: make([]byte, 16), // just nonce size
|
||||
data: make([]byte, 24), // Just nonce size
|
||||
key: key,
|
||||
errMsg: "message authentication failed",
|
||||
errMsg: "decryption failed",
|
||||
},
|
||||
{
|
||||
name: "tampered ciphertext",
|
||||
data: tamperWithBytes(encrypted),
|
||||
key: key,
|
||||
errMsg: "decryption failed",
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
data: []byte("<22>`M<><4D><EFBFBD>l\u001AIF<49>\u0012<31><32><EFBFBD>=h<>?<3F>c<EFBFBD> <20><>\u0012<31><32><EFBFBD><EFBFBD>\u001C<31>\u0018Ƽ(g"),
|
||||
data: encrypted,
|
||||
key: nil,
|
||||
errMsg: "encryption key is required",
|
||||
},
|
||||
@@ -116,9 +125,20 @@ func TestDecryptionErrors(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := DecryptStringWithAESKey(tt.data, tt.key)
|
||||
_, err := DecryptStringWithKey(tt.data, tt.key)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// tamperWithBytes modifies a byte in the encrypted data to simulate tampering
|
||||
func tamperWithBytes(data []byte) []byte {
|
||||
if len(data) < 25 { // Need at least nonce + 1 byte
|
||||
return data
|
||||
}
|
||||
tampered := make([]byte, len(data))
|
||||
copy(tampered, data)
|
||||
tampered[24] ^= 0x01 // Modify first byte after nonce
|
||||
return tampered
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func (m *Meta) GetEncryptedString(key string, encryptionKey []byte) (string, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
decrypted, err := crypto.DecryptStringWithAESKey(v, encryptionKey)
|
||||
decrypted, err := crypto.DecryptStringWithKey(v, encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func (m *Meta) GetEncryptedBytes(key string, encryptionKey []byte) ([]byte, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypted, err := crypto.DecryptStringWithAESKey(v, encryptionKey)
|
||||
decrypted, err := crypto.DecryptStringWithKey(v, encryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -150,18 +150,19 @@ func (m *Meta) Add(key string, val any) error {
|
||||
// AddEncrypted adds a key/value pair in the meta set.
|
||||
// The value is encrypted with the given encryptionKey.
|
||||
// Accepted types for the value are: string, []byte.
|
||||
// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead.
|
||||
func (m *Meta) AddEncrypted(key string, val any, encryptionKey []byte) error {
|
||||
var encrypted []byte
|
||||
var err error
|
||||
|
||||
switch val := val.(type) {
|
||||
case string:
|
||||
encrypted, err = crypto.EncryptWithAESKey([]byte(val), encryptionKey)
|
||||
encrypted, err = crypto.EncryptWithKey([]byte(val), encryptionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case []byte:
|
||||
encrypted, err = crypto.EncryptWithAESKey(val, encryptionKey)
|
||||
encrypted, err = crypto.EncryptWithKey(val, encryptionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,10 +9,15 @@ import (
|
||||
"github.com/ipld/go-ipld-prime/must"
|
||||
"github.com/ipld/go-ipld-prime/node/basicnode"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/selector"
|
||||
)
|
||||
|
||||
func FromIPLD(node datamodel.Node) (Policy, error) {
|
||||
if err := limits.ValidateIntegerBoundsIPLD(node); err != nil {
|
||||
return nil, fmt.Errorf("policy contains integer values outside safe bounds: %w", err)
|
||||
}
|
||||
|
||||
return statementsFromIPLD("/", node)
|
||||
}
|
||||
|
||||
|
||||
49
pkg/policy/limits/int.go
Normal file
49
pkg/policy/limits/int.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package limits
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ipld/go-ipld-prime"
|
||||
"github.com/ipld/go-ipld-prime/must"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxInt53 represents the maximum safe integer in JavaScript (2^53 - 1)
|
||||
MaxInt53 = 9007199254740991
|
||||
// MinInt53 represents the minimum safe integer in JavaScript (-2^53 + 1)
|
||||
MinInt53 = -9007199254740991
|
||||
)
|
||||
|
||||
func ValidateIntegerBoundsIPLD(node ipld.Node) error {
|
||||
switch node.Kind() {
|
||||
case ipld.Kind_Int:
|
||||
val := must.Int(node)
|
||||
if val > MaxInt53 || val < MinInt53 {
|
||||
return fmt.Errorf("integer value %d exceeds safe bounds", val)
|
||||
}
|
||||
case ipld.Kind_List:
|
||||
it := node.ListIterator()
|
||||
for !it.Done() {
|
||||
_, v, err := it.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateIntegerBoundsIPLD(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case ipld.Kind_Map:
|
||||
it := node.MapIterator()
|
||||
for !it.Done() {
|
||||
_, v, err := it.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateIntegerBoundsIPLD(v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
82
pkg/policy/limits/int_test.go
Normal file
82
pkg/policy/limits/int_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package limits
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ipld/go-ipld-prime/datamodel"
|
||||
"github.com/ipld/go-ipld-prime/fluent/qp"
|
||||
"github.com/ipld/go-ipld-prime/node/basicnode"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateIntegerBoundsIPLD(t *testing.T) {
|
||||
buildMap := func() datamodel.Node {
|
||||
nb := basicnode.Prototype.Any.NewBuilder()
|
||||
qp.Map(1, func(ma datamodel.MapAssembler) {
|
||||
qp.MapEntry(ma, "foo", qp.Int(MaxInt53+1))
|
||||
})(nb)
|
||||
return nb.Build()
|
||||
}
|
||||
|
||||
buildList := func() datamodel.Node {
|
||||
nb := basicnode.Prototype.Any.NewBuilder()
|
||||
qp.List(1, func(la datamodel.ListAssembler) {
|
||||
qp.ListEntry(la, qp.Int(MinInt53-1))
|
||||
})(nb)
|
||||
return nb.Build()
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input datamodel.Node
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid int",
|
||||
input: basicnode.NewInt(42),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "max safe int",
|
||||
input: basicnode.NewInt(MaxInt53),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "min safe int",
|
||||
input: basicnode.NewInt(MinInt53),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "above MaxInt53",
|
||||
input: basicnode.NewInt(MaxInt53 + 1),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "below MinInt53",
|
||||
input: basicnode.NewInt(MinInt53 - 1),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nested map with invalid int",
|
||||
input: buildMap(),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nested list with invalid int",
|
||||
input: buildList(),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateIntegerBoundsIPLD(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "exceeds safe bounds")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"github.com/ipld/go-ipld-prime/fluent/qp"
|
||||
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
|
||||
"github.com/ipld/go-ipld-prime/node/basicnode"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
)
|
||||
|
||||
var Bool = basicnode.NewBool
|
||||
@@ -58,8 +60,6 @@ func List[T any](l []T) (ipld.Node, error) {
|
||||
// Any creates an IPLD node from any value
|
||||
// If possible, use another dedicated function for your type for performance.
|
||||
func Any(v any) (res ipld.Node, err error) {
|
||||
// TODO: handle uint overflow below
|
||||
|
||||
// some fast path
|
||||
switch val := v.(type) {
|
||||
case bool:
|
||||
@@ -67,7 +67,11 @@ func Any(v any) (res ipld.Node, err error) {
|
||||
case string:
|
||||
return basicnode.NewString(val), nil
|
||||
case int:
|
||||
return basicnode.NewInt(int64(val)), nil
|
||||
i := int64(val)
|
||||
if i > limits.MaxInt53 || i < limits.MinInt53 {
|
||||
return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", i)
|
||||
}
|
||||
return basicnode.NewInt(i), nil
|
||||
case int8:
|
||||
return basicnode.NewInt(int64(val)), nil
|
||||
case int16:
|
||||
@@ -75,6 +79,9 @@ func Any(v any) (res ipld.Node, err error) {
|
||||
case int32:
|
||||
return basicnode.NewInt(int64(val)), nil
|
||||
case int64:
|
||||
if val > limits.MaxInt53 || val < limits.MinInt53 {
|
||||
return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", val)
|
||||
}
|
||||
return basicnode.NewInt(val), nil
|
||||
case uint:
|
||||
return basicnode.NewInt(int64(val)), nil
|
||||
@@ -85,6 +92,9 @@ func Any(v any) (res ipld.Node, err error) {
|
||||
case uint32:
|
||||
return basicnode.NewInt(int64(val)), nil
|
||||
case uint64:
|
||||
if val > uint64(limits.MaxInt53) {
|
||||
return nil, fmt.Errorf("unsigned integer value %d exceeds safe integer bounds", val)
|
||||
}
|
||||
return basicnode.NewInt(int64(val)), nil
|
||||
case float32:
|
||||
return basicnode.NewFloat(float64(val)), nil
|
||||
@@ -168,9 +178,17 @@ func anyAssemble(val any) qp.Assemble {
|
||||
case reflect.Bool:
|
||||
return qp.Bool(rv.Bool())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return qp.Int(rv.Int())
|
||||
i := rv.Int()
|
||||
if i > limits.MaxInt53 || i < limits.MinInt53 {
|
||||
panic(fmt.Sprintf("integer %d exceeds safe bounds", i))
|
||||
}
|
||||
return qp.Int(i)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return qp.Int(int64(rv.Uint()))
|
||||
u := rv.Uint()
|
||||
if u > limits.MaxInt53 {
|
||||
panic(fmt.Sprintf("unsigned integer %d exceeds safe bounds", u))
|
||||
}
|
||||
return qp.Int(int64(u))
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return qp.Float(rv.Float())
|
||||
case reflect.String:
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
|
||||
"github.com/ipld/go-ipld-prime/printer"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
)
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
@@ -214,7 +216,7 @@ func TestAny(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm")))
|
||||
|
||||
v, err = Any(data["func"])
|
||||
_, err = Any(data["func"])
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -254,6 +256,56 @@ func BenchmarkAny(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnyAssembleIntegerOverflow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid int",
|
||||
input: 42,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "max safe int",
|
||||
input: limits.MaxInt53,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "min safe int",
|
||||
input: limits.MinInt53,
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "overflow int",
|
||||
input: int64(limits.MaxInt53 + 1),
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "underflow int",
|
||||
input: int64(limits.MinInt53 - 1),
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "overflow uint",
|
||||
input: uint64(limits.MaxInt53 + 1),
|
||||
shouldErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := Any(tt.input)
|
||||
if tt.shouldErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func must[T any](t T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -3,6 +3,7 @@ package policy
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ipld/go-ipld-prime"
|
||||
"github.com/ipld/go-ipld-prime/datamodel"
|
||||
@@ -249,10 +250,22 @@ func matchStatement(cur Statement, node ipld.Node) (_ matchResult, leafMost Stat
|
||||
panic(fmt.Errorf("unimplemented statement kind: %s", cur.Kind()))
|
||||
}
|
||||
|
||||
// isOrdered compares two IPLD nodes and returns true if they satisfy the given ordering function.
|
||||
// It supports comparison of integers and floats, returning false for:
|
||||
// - Nodes of different or unsupported kinds
|
||||
// - Integer values outside JavaScript's safe integer bounds (±2^53-1)
|
||||
// - Non-finite floating point values (NaN or ±Inf)
|
||||
//
|
||||
// The satisfies parameter is a function that interprets the comparison result:
|
||||
// - For ">" it returns true when order is 1
|
||||
// - For ">=" it returns true when order is 0 or 1
|
||||
// - For "<" it returns true when order is -1
|
||||
// - For "<=" it returns true when order is -1 or 0
|
||||
func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool {
|
||||
if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int {
|
||||
a := must.Int(actual)
|
||||
b := must.Int(expected)
|
||||
|
||||
return satisfies(cmp.Compare(a, b))
|
||||
}
|
||||
|
||||
@@ -265,6 +278,11 @@ func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) b
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("extracting selector float: %w", err))
|
||||
}
|
||||
|
||||
if math.IsInf(a, 0) || math.IsNaN(a) || math.IsInf(b, 0) || math.IsNaN(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
return satisfies(cmp.Compare(a, b))
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -67,6 +69,9 @@ func Parse(str string) (Selector, error) {
|
||||
if err != nil {
|
||||
return nil, newParseError("invalid index", str, col, tok)
|
||||
}
|
||||
if idx > limits.MaxInt53 || idx < limits.MinInt53 {
|
||||
return nil, newParseError(fmt.Sprintf("index %d exceeds safe integer bounds", idx), str, col, tok)
|
||||
}
|
||||
sel = append(sel, segment{str: tok, optional: opt, index: idx})
|
||||
|
||||
// explicit field, ["abcd"]
|
||||
@@ -88,6 +93,9 @@ func Parse(str string) (Selector, error) {
|
||||
if err != nil {
|
||||
return nil, newParseError("invalid slice index", str, col, tok)
|
||||
}
|
||||
if i > limits.MaxInt53 || i < limits.MinInt53 {
|
||||
return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok)
|
||||
}
|
||||
rng[0] = i
|
||||
}
|
||||
if splt[1] == "" {
|
||||
@@ -97,6 +105,9 @@ func Parse(str string) (Selector, error) {
|
||||
if err != nil {
|
||||
return nil, newParseError("invalid slice index", str, col, tok)
|
||||
}
|
||||
if i > limits.MaxInt53 || i < limits.MinInt53 {
|
||||
return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok)
|
||||
}
|
||||
rng[1] = i
|
||||
}
|
||||
sel = append(sel, segment{str: tok, optional: opt, slice: rng[:]})
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package selector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
@@ -616,4 +618,23 @@ func TestParse(t *testing.T) {
|
||||
require.Nil(t, sel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("integer overflow", func(t *testing.T) {
|
||||
sel, err := Parse(fmt.Sprintf(".[%d]", limits.MaxInt53+1))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, sel)
|
||||
|
||||
sel, err = Parse(fmt.Sprintf(".[%d]", limits.MinInt53-1))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, sel)
|
||||
|
||||
// Test slice overflow
|
||||
sel, err = Parse(fmt.Sprintf(".[%d:42]", limits.MaxInt53+1))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, sel)
|
||||
|
||||
sel, err = Parse(fmt.Sprintf(".[1:%d]", limits.MaxInt53+1))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, sel)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -266,19 +266,32 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) {
|
||||
case slice[0] == math.MinInt:
|
||||
start = 0
|
||||
case slice[0] < 0:
|
||||
start = length + slice[0]
|
||||
// Check for potential overflow before adding
|
||||
if -slice[0] > length {
|
||||
start = 0
|
||||
} else {
|
||||
start = length + slice[0]
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case slice[1] == math.MaxInt:
|
||||
end = length
|
||||
case slice[1] < 0:
|
||||
end = length + slice[1]
|
||||
// Check for potential overflow before adding
|
||||
if -slice[1] > length {
|
||||
end = 0
|
||||
} else {
|
||||
end = length + slice[1]
|
||||
}
|
||||
}
|
||||
|
||||
// backward iteration is not allowed, shortcut to an empty result
|
||||
if start >= end {
|
||||
start, end = 0, 0
|
||||
return
|
||||
}
|
||||
|
||||
// clamp out of bound
|
||||
if start < 0 {
|
||||
start = 0
|
||||
@@ -286,11 +299,14 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) {
|
||||
if start > length {
|
||||
start = length
|
||||
}
|
||||
if end < 0 {
|
||||
end = 0
|
||||
}
|
||||
if end > length {
|
||||
end = length
|
||||
}
|
||||
|
||||
return start, end
|
||||
return
|
||||
}
|
||||
|
||||
func kindString(n datamodel.Node) string {
|
||||
|
||||
@@ -2,6 +2,7 @@ package selector
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -356,3 +357,57 @@ func FuzzParseAndSelect(f *testing.F) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveSliceIndices(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
slice []int64
|
||||
length int64
|
||||
wantStart int64
|
||||
wantEnd int64
|
||||
}{
|
||||
{
|
||||
name: "normal case",
|
||||
slice: []int64{1, 3},
|
||||
length: 5,
|
||||
wantStart: 1,
|
||||
wantEnd: 3,
|
||||
},
|
||||
{
|
||||
name: "negative indices",
|
||||
slice: []int64{-2, -1},
|
||||
length: 5,
|
||||
wantStart: 3,
|
||||
wantEnd: 4,
|
||||
},
|
||||
{
|
||||
name: "overflow protection negative start",
|
||||
slice: []int64{math.MinInt64, 3},
|
||||
length: 5,
|
||||
wantStart: 0,
|
||||
wantEnd: 3,
|
||||
},
|
||||
{
|
||||
name: "overflow protection negative end",
|
||||
slice: []int64{0, math.MinInt64},
|
||||
length: 5,
|
||||
wantStart: 0,
|
||||
wantEnd: 0,
|
||||
},
|
||||
{
|
||||
name: "max bounds",
|
||||
slice: []int64{0, math.MaxInt64},
|
||||
length: 5,
|
||||
wantStart: 0,
|
||||
wantEnd: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
start, end := resolveSliceIndices(tt.slice, tt.length)
|
||||
require.Equal(t, tt.wantStart, start)
|
||||
require.Equal(t, tt.wantEnd, end)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,8 +215,15 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) {
|
||||
|
||||
tkn.meta = m.Meta
|
||||
|
||||
tkn.notBefore = parse.OptionalTimestamp(m.Nbf)
|
||||
tkn.expiration = parse.OptionalTimestamp(m.Exp)
|
||||
tkn.notBefore, err = parse.OptionalTimestamp(m.Nbf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse notBefore: %w", err)
|
||||
}
|
||||
|
||||
tkn.expiration, err = parse.OptionalTimestamp(m.Exp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
if err := tkn.validate(); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -45,7 +45,8 @@ func WithMeta(key string, val any) Option {
|
||||
}
|
||||
|
||||
// WithEncryptedMetaString adds a key/value pair in the "meta" field.
|
||||
// The string value is encrypted with the given aesKey.
|
||||
// The string value is encrypted with the given key.
|
||||
// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead.
|
||||
func WithEncryptedMetaString(key, val string, encryptionKey []byte) Option {
|
||||
return func(t *Token) error {
|
||||
return t.meta.AddEncrypted(key, val, encryptionKey)
|
||||
@@ -53,7 +54,8 @@ func WithEncryptedMetaString(key, val string, encryptionKey []byte) Option {
|
||||
}
|
||||
|
||||
// WithEncryptedMetaBytes adds a key/value pair in the "meta" field.
|
||||
// The []byte value is encrypted with the given aesKey.
|
||||
// The []byte value is encrypted with the given key.
|
||||
// The ciphertext will be 40 bytes larger than the plaintext due to encryption overhead.
|
||||
func WithEncryptedMetaBytes(key string, val, encryptionKey []byte) Option {
|
||||
return func(t *Token) error {
|
||||
return t.meta.AddEncrypted(key, val, encryptionKey)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package parse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ucan-wg/go-ucan/did"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
)
|
||||
|
||||
func OptionalDID(s *string) (did.DID, error) {
|
||||
@@ -13,10 +15,15 @@ func OptionalDID(s *string) (did.DID, error) {
|
||||
return did.Parse(*s)
|
||||
}
|
||||
|
||||
func OptionalTimestamp(sec *int64) *time.Time {
|
||||
func OptionalTimestamp(sec *int64) (*time.Time, error) {
|
||||
if sec == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if *sec > limits.MaxInt53 || *sec < limits.MinInt53 {
|
||||
return nil, fmt.Errorf("timestamp value %d exceeds safe integer bounds", *sec)
|
||||
}
|
||||
|
||||
t := time.Unix(*sec, 0)
|
||||
return &t
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
64
token/internal/parse/parse_test.go
Normal file
64
token/internal/parse/parse_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package parse
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
|
||||
)
|
||||
|
||||
func TestOptionalTimestamp(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input *int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil timestamp",
|
||||
input: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid timestamp",
|
||||
input: int64Ptr(1625097600),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "max safe integer",
|
||||
input: int64Ptr(limits.MaxInt53),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "exceeds max safe integer",
|
||||
input: int64Ptr(limits.MaxInt53 + 1),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "below min safe integer",
|
||||
input: int64Ptr(limits.MinInt53 - 1),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := OptionalTimestamp(tt.input)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "exceeds safe integer bounds")
|
||||
require.Nil(t, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
if tt.input == nil {
|
||||
require.Nil(t, result)
|
||||
} else {
|
||||
require.NotNil(t, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func int64Ptr(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
@@ -272,11 +272,22 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) {
|
||||
tkn.nonce = m.Nonce
|
||||
|
||||
tkn.arguments = m.Args
|
||||
if err := tkn.arguments.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid arguments: %w", err)
|
||||
}
|
||||
|
||||
tkn.proof = m.Prf
|
||||
tkn.meta = m.Meta
|
||||
|
||||
tkn.expiration = parse.OptionalTimestamp(m.Exp)
|
||||
tkn.invokedAt = parse.OptionalTimestamp(m.Iat)
|
||||
tkn.expiration, err = parse.OptionalTimestamp(m.Exp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
tkn.invokedAt, err = parse.OptionalTimestamp(m.Iat)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse invokedAt: %w", err)
|
||||
}
|
||||
|
||||
tkn.cause = m.Cause
|
||||
|
||||
|
||||
Reference in New Issue
Block a user