Merge pull request #89 from ucan-wg/fix/prevent-int-overflow

fix: prevent overflow of int values
This commit is contained in:
Michael Muré
2024-12-02 14:41:12 +01:00
committed by GitHub
16 changed files with 514 additions and 16 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/ipld/go-ipld-prime/printer"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
)
@@ -62,6 +63,10 @@ func (a *Args) Add(key string, val any) error {
return err
}
if err := limits.ValidateIntegerBoundsIPLD(node); err != nil {
return fmt.Errorf("value for key %q: %w", key, err)
}
a.Values[key] = node
a.Keys = append(a.Keys, key)
@@ -164,3 +169,14 @@ func (a *Args) Clone() *Args {
}
return res
}
// Validate checks that all values in the Args are valid according to UCAN specs
func (a *Args) Validate() error {
for key, value := range a.Values {
if err := limits.ValidateIntegerBoundsIPLD(value); err != nil {
return fmt.Errorf("value for key %q: %w", key, err)
}
}
return nil
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/args"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
"github.com/ucan-wg/go-ucan/pkg/policy/literal"
)
@@ -185,6 +186,71 @@ func TestInclude(t *testing.T) {
}, maps.Collect(a1.Iter()))
}
func TestArgsIntegerBounds(t *testing.T) {
t.Parallel()
tests := []struct {
name string
key string
val int64
wantErr string
}{
{
name: "valid int",
key: "valid",
val: 42,
},
{
name: "max safe integer",
key: "max",
val: limits.MaxInt53,
},
{
name: "min safe integer",
key: "min",
val: limits.MinInt53,
},
{
name: "exceeds max safe integer",
key: "tooBig",
val: limits.MaxInt53 + 1,
wantErr: "exceeds safe integer bounds",
},
{
name: "below min safe integer",
key: "tooSmall",
val: limits.MinInt53 - 1,
wantErr: "exceeds safe integer bounds",
},
{
name: "duplicate key",
key: "duplicate",
val: 42,
wantErr: "duplicate key",
},
}
a := args.New()
require.NoError(t, a.Add("duplicate", 1)) // tests duplicate key
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := a.Add(tt.key, tt.val)
if tt.wantErr != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.wantErr)
} else {
require.NoError(t, err)
val, err := a.GetNode(tt.key)
require.NoError(t, err)
i, err := val.AsInt()
require.NoError(t, err)
require.Equal(t, tt.val, i)
}
})
}
}
const (
argsSchema = "type Args { String : Any }"
argsName = "Args"

View File

@@ -9,10 +9,15 @@ import (
"github.com/ipld/go-ipld-prime/must"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
"github.com/ucan-wg/go-ucan/pkg/policy/selector"
)
func FromIPLD(node datamodel.Node) (Policy, error) {
if err := limits.ValidateIntegerBoundsIPLD(node); err != nil {
return nil, fmt.Errorf("policy contains integer values outside safe bounds: %w", err)
}
return statementsFromIPLD("/", node)
}

49
pkg/policy/limits/int.go Normal file
View File

@@ -0,0 +1,49 @@
package limits
import (
"fmt"
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/must"
)
const (
// MaxInt53 represents the maximum safe integer in JavaScript (2^53 - 1)
MaxInt53 = 9007199254740991
// MinInt53 represents the minimum safe integer in JavaScript (-2^53 + 1)
MinInt53 = -9007199254740991
)
func ValidateIntegerBoundsIPLD(node ipld.Node) error {
switch node.Kind() {
case ipld.Kind_Int:
val := must.Int(node)
if val > MaxInt53 || val < MinInt53 {
return fmt.Errorf("integer value %d exceeds safe bounds", val)
}
case ipld.Kind_List:
it := node.ListIterator()
for !it.Done() {
_, v, err := it.Next()
if err != nil {
return err
}
if err := ValidateIntegerBoundsIPLD(v); err != nil {
return err
}
}
case ipld.Kind_Map:
it := node.MapIterator()
for !it.Done() {
_, v, err := it.Next()
if err != nil {
return err
}
if err := ValidateIntegerBoundsIPLD(v); err != nil {
return err
}
}
}
return nil
}

View File

@@ -0,0 +1,82 @@
package limits
import (
"testing"
"github.com/ipld/go-ipld-prime/datamodel"
"github.com/ipld/go-ipld-prime/fluent/qp"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/stretchr/testify/require"
)
func TestValidateIntegerBoundsIPLD(t *testing.T) {
buildMap := func() datamodel.Node {
nb := basicnode.Prototype.Any.NewBuilder()
qp.Map(1, func(ma datamodel.MapAssembler) {
qp.MapEntry(ma, "foo", qp.Int(MaxInt53+1))
})(nb)
return nb.Build()
}
buildList := func() datamodel.Node {
nb := basicnode.Prototype.Any.NewBuilder()
qp.List(1, func(la datamodel.ListAssembler) {
qp.ListEntry(la, qp.Int(MinInt53-1))
})(nb)
return nb.Build()
}
tests := []struct {
name string
input datamodel.Node
wantErr bool
}{
{
name: "valid int",
input: basicnode.NewInt(42),
wantErr: false,
},
{
name: "max safe int",
input: basicnode.NewInt(MaxInt53),
wantErr: false,
},
{
name: "min safe int",
input: basicnode.NewInt(MinInt53),
wantErr: false,
},
{
name: "above MaxInt53",
input: basicnode.NewInt(MaxInt53 + 1),
wantErr: true,
},
{
name: "below MinInt53",
input: basicnode.NewInt(MinInt53 - 1),
wantErr: true,
},
{
name: "nested map with invalid int",
input: buildMap(),
wantErr: true,
},
{
name: "nested list with invalid int",
input: buildList(),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateIntegerBoundsIPLD(tt.input)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "exceeds safe bounds")
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -12,6 +12,8 @@ import (
"github.com/ipld/go-ipld-prime/fluent/qp"
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
"github.com/ipld/go-ipld-prime/node/basicnode"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
)
var Bool = basicnode.NewBool
@@ -58,8 +60,6 @@ func List[T any](l []T) (ipld.Node, error) {
// Any creates an IPLD node from any value
// If possible, use another dedicated function for your type for performance.
func Any(v any) (res ipld.Node, err error) {
// TODO: handle uint overflow below
// some fast path
switch val := v.(type) {
case bool:
@@ -67,7 +67,11 @@ func Any(v any) (res ipld.Node, err error) {
case string:
return basicnode.NewString(val), nil
case int:
return basicnode.NewInt(int64(val)), nil
i := int64(val)
if i > limits.MaxInt53 || i < limits.MinInt53 {
return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", i)
}
return basicnode.NewInt(i), nil
case int8:
return basicnode.NewInt(int64(val)), nil
case int16:
@@ -75,6 +79,9 @@ func Any(v any) (res ipld.Node, err error) {
case int32:
return basicnode.NewInt(int64(val)), nil
case int64:
if val > limits.MaxInt53 || val < limits.MinInt53 {
return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", val)
}
return basicnode.NewInt(val), nil
case uint:
return basicnode.NewInt(int64(val)), nil
@@ -85,6 +92,9 @@ func Any(v any) (res ipld.Node, err error) {
case uint32:
return basicnode.NewInt(int64(val)), nil
case uint64:
if val > uint64(limits.MaxInt53) {
return nil, fmt.Errorf("unsigned integer value %d exceeds safe integer bounds", val)
}
return basicnode.NewInt(int64(val)), nil
case float32:
return basicnode.NewFloat(float64(val)), nil
@@ -168,9 +178,17 @@ func anyAssemble(val any) qp.Assemble {
case reflect.Bool:
return qp.Bool(rv.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return qp.Int(rv.Int())
i := rv.Int()
if i > limits.MaxInt53 || i < limits.MinInt53 {
panic(fmt.Sprintf("integer %d exceeds safe bounds", i))
}
return qp.Int(i)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return qp.Int(int64(rv.Uint()))
u := rv.Uint()
if u > limits.MaxInt53 {
panic(fmt.Sprintf("unsigned integer %d exceeds safe bounds", u))
}
return qp.Int(int64(u))
case reflect.Float32, reflect.Float64:
return qp.Float(rv.Float())
case reflect.String:

View File

@@ -8,6 +8,8 @@ import (
cidlink "github.com/ipld/go-ipld-prime/linking/cid"
"github.com/ipld/go-ipld-prime/printer"
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
)
func TestList(t *testing.T) {
@@ -214,7 +216,7 @@ func TestAny(t *testing.T) {
require.NoError(t, err)
require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm")))
v, err = Any(data["func"])
_, err = Any(data["func"])
require.Error(t, err)
}
@@ -254,6 +256,56 @@ func BenchmarkAny(b *testing.B) {
})
}
func TestAnyAssembleIntegerOverflow(t *testing.T) {
tests := []struct {
name string
input interface{}
shouldErr bool
}{
{
name: "valid int",
input: 42,
shouldErr: false,
},
{
name: "max safe int",
input: limits.MaxInt53,
shouldErr: false,
},
{
name: "min safe int",
input: limits.MinInt53,
shouldErr: false,
},
{
name: "overflow int",
input: int64(limits.MaxInt53 + 1),
shouldErr: true,
},
{
name: "underflow int",
input: int64(limits.MinInt53 - 1),
shouldErr: true,
},
{
name: "overflow uint",
input: uint64(limits.MaxInt53 + 1),
shouldErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Any(tt.input)
if tt.shouldErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}
func must[T any](t T, err error) T {
if err != nil {
panic(err)

View File

@@ -3,6 +3,7 @@ package policy
import (
"cmp"
"fmt"
"math"
"github.com/ipld/go-ipld-prime"
"github.com/ipld/go-ipld-prime/datamodel"
@@ -249,10 +250,22 @@ func matchStatement(cur Statement, node ipld.Node) (_ matchResult, leafMost Stat
panic(fmt.Errorf("unimplemented statement kind: %s", cur.Kind()))
}
// isOrdered compares two IPLD nodes and returns true if they satisfy the given ordering function.
// It supports comparison of integers and floats, returning false for:
// - Nodes of different or unsupported kinds
// - Integer values outside JavaScript's safe integer bounds (±2^53-1)
// - Non-finite floating point values (NaN or ±Inf)
//
// The satisfies parameter is a function that interprets the comparison result:
// - For ">" it returns true when order is 1
// - For ">=" it returns true when order is 0 or 1
// - For "<" it returns true when order is -1
// - For "<=" it returns true when order is -1 or 0
func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool {
if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int {
a := must.Int(actual)
b := must.Int(expected)
return satisfies(cmp.Compare(a, b))
}
@@ -265,6 +278,11 @@ func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) b
if err != nil {
panic(fmt.Errorf("extracting selector float: %w", err))
}
if math.IsInf(a, 0) || math.IsNaN(a) || math.IsInf(b, 0) || math.IsNaN(b) {
return false
}
return satisfies(cmp.Compare(a, b))
}

View File

@@ -6,6 +6,8 @@ import (
"regexp"
"strconv"
"strings"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
)
var (
@@ -56,6 +58,9 @@ func Parse(str string) (Selector, error) {
if err != nil {
return nil, newParseError("invalid index", str, col, tok)
}
if idx > limits.MaxInt53 || idx < limits.MinInt53 {
return nil, newParseError(fmt.Sprintf("index %d exceeds safe integer bounds", idx), str, col, tok)
}
sel = append(sel, segment{str: tok, optional: opt, index: idx})
// explicit field, ["abcd"]
@@ -77,6 +82,9 @@ func Parse(str string) (Selector, error) {
if err != nil {
return nil, newParseError("invalid slice index", str, col, tok)
}
if i > limits.MaxInt53 || i < limits.MinInt53 {
return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok)
}
rng[0] = i
}
if splt[1] == "" {
@@ -86,6 +94,9 @@ func Parse(str string) (Selector, error) {
if err != nil {
return nil, newParseError("invalid slice index", str, col, tok)
}
if i > limits.MaxInt53 || i < limits.MinInt53 {
return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok)
}
rng[1] = i
}
sel = append(sel, segment{str: tok, optional: opt, slice: rng[:]})

View File

@@ -1,10 +1,12 @@
package selector
import (
"fmt"
"math"
"testing"
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
)
func TestParse(t *testing.T) {
@@ -572,4 +574,23 @@ func TestParse(t *testing.T) {
_, err := Parse(".[foo]")
require.Error(t, err)
})
t.Run("integer overflow", func(t *testing.T) {
sel, err := Parse(fmt.Sprintf(".[%d]", limits.MaxInt53+1))
require.Error(t, err)
require.Nil(t, sel)
sel, err = Parse(fmt.Sprintf(".[%d]", limits.MinInt53-1))
require.Error(t, err)
require.Nil(t, sel)
// Test slice overflow
sel, err = Parse(fmt.Sprintf(".[%d:42]", limits.MaxInt53+1))
require.Error(t, err)
require.Nil(t, sel)
sel, err = Parse(fmt.Sprintf(".[1:%d]", limits.MaxInt53+1))
require.Error(t, err)
require.Nil(t, sel)
})
}

View File

@@ -266,19 +266,32 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) {
case slice[0] == math.MinInt:
start = 0
case slice[0] < 0:
// Check for potential overflow before adding
if -slice[0] > length {
start = 0
} else {
start = length + slice[0]
}
}
switch {
case slice[1] == math.MaxInt:
end = length
case slice[1] < 0:
// Check for potential overflow before adding
if -slice[1] > length {
end = 0
} else {
end = length + slice[1]
}
}
// backward iteration is not allowed, shortcut to an empty result
if start >= end {
start, end = 0, 0
return
}
// clamp out of bound
if start < 0 {
start = 0
@@ -286,11 +299,14 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) {
if start > length {
start = length
}
if end < 0 {
end = 0
}
if end > length {
end = length
}
return start, end
return
}
func kindString(n datamodel.Node) string {

View File

@@ -2,6 +2,7 @@ package selector
import (
"errors"
"math"
"strings"
"testing"
@@ -356,3 +357,57 @@ func FuzzParseAndSelect(f *testing.F) {
}
})
}
func TestResolveSliceIndices(t *testing.T) {
tests := []struct {
name string
slice []int64
length int64
wantStart int64
wantEnd int64
}{
{
name: "normal case",
slice: []int64{1, 3},
length: 5,
wantStart: 1,
wantEnd: 3,
},
{
name: "negative indices",
slice: []int64{-2, -1},
length: 5,
wantStart: 3,
wantEnd: 4,
},
{
name: "overflow protection negative start",
slice: []int64{math.MinInt64, 3},
length: 5,
wantStart: 0,
wantEnd: 3,
},
{
name: "overflow protection negative end",
slice: []int64{0, math.MinInt64},
length: 5,
wantStart: 0,
wantEnd: 0,
},
{
name: "max bounds",
slice: []int64{0, math.MaxInt64},
length: 5,
wantStart: 0,
wantEnd: 5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start, end := resolveSliceIndices(tt.slice, tt.length)
require.Equal(t, tt.wantStart, start)
require.Equal(t, tt.wantEnd, end)
})
}
}

View File

@@ -215,8 +215,15 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) {
tkn.meta = m.Meta
tkn.notBefore = parse.OptionalTimestamp(m.Nbf)
tkn.expiration = parse.OptionalTimestamp(m.Exp)
tkn.notBefore, err = parse.OptionalTimestamp(m.Nbf)
if err != nil {
return nil, fmt.Errorf("parse notBefore: %w", err)
}
tkn.expiration, err = parse.OptionalTimestamp(m.Exp)
if err != nil {
return nil, fmt.Errorf("parse expiration: %w", err)
}
if err := tkn.validate(); err != nil {
return nil, err

View File

@@ -1,9 +1,11 @@
package parse
import (
"fmt"
"time"
"github.com/ucan-wg/go-ucan/did"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
)
func OptionalDID(s *string) (did.DID, error) {
@@ -13,10 +15,15 @@ func OptionalDID(s *string) (did.DID, error) {
return did.Parse(*s)
}
func OptionalTimestamp(sec *int64) *time.Time {
func OptionalTimestamp(sec *int64) (*time.Time, error) {
if sec == nil {
return nil
return nil, nil
}
if *sec > limits.MaxInt53 || *sec < limits.MinInt53 {
return nil, fmt.Errorf("timestamp value %d exceeds safe integer bounds", *sec)
}
t := time.Unix(*sec, 0)
return &t
return &t, nil
}

View File

@@ -0,0 +1,64 @@
package parse
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/ucan-wg/go-ucan/pkg/policy/limits"
)
func TestOptionalTimestamp(t *testing.T) {
tests := []struct {
name string
input *int64
wantErr bool
}{
{
name: "nil timestamp",
input: nil,
wantErr: false,
},
{
name: "valid timestamp",
input: int64Ptr(1625097600),
wantErr: false,
},
{
name: "max safe integer",
input: int64Ptr(limits.MaxInt53),
wantErr: false,
},
{
name: "exceeds max safe integer",
input: int64Ptr(limits.MaxInt53 + 1),
wantErr: true,
},
{
name: "below min safe integer",
input: int64Ptr(limits.MinInt53 - 1),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := OptionalTimestamp(tt.input)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "exceeds safe integer bounds")
require.Nil(t, result)
} else {
require.NoError(t, err)
if tt.input == nil {
require.Nil(t, result)
} else {
require.NotNil(t, result)
}
}
})
}
}
func int64Ptr(i int64) *int64 {
return &i
}

View File

@@ -272,11 +272,22 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) {
tkn.nonce = m.Nonce
tkn.arguments = m.Args
if err := tkn.arguments.Validate(); err != nil {
return nil, fmt.Errorf("invalid arguments: %w", err)
}
tkn.proof = m.Prf
tkn.meta = m.Meta
tkn.expiration = parse.OptionalTimestamp(m.Exp)
tkn.invokedAt = parse.OptionalTimestamp(m.Iat)
tkn.expiration, err = parse.OptionalTimestamp(m.Exp)
if err != nil {
return nil, fmt.Errorf("parse expiration: %w", err)
}
tkn.invokedAt, err = parse.OptionalTimestamp(m.Iat)
if err != nil {
return nil, fmt.Errorf("parse invokedAt: %w", err)
}
tkn.cause = m.Cause