diff --git a/pkg/args/args.go b/pkg/args/args.go index 3cee3ce..a840c6e 100644 --- a/pkg/args/args.go +++ b/pkg/args/args.go @@ -16,6 +16,7 @@ import ( "github.com/ipld/go-ipld-prime/node/basicnode" "github.com/ipld/go-ipld-prime/printer" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" "github.com/ucan-wg/go-ucan/pkg/policy/literal" ) @@ -62,6 +63,10 @@ func (a *Args) Add(key string, val any) error { return err } + if err := limits.ValidateIntegerBoundsIPLD(node); err != nil { + return fmt.Errorf("value for key %q: %w", key, err) + } + a.Values[key] = node a.Keys = append(a.Keys, key) @@ -164,3 +169,14 @@ func (a *Args) Clone() *Args { } return res } + +// Validate checks that all values in the Args are valid according to UCAN specs +func (a *Args) Validate() error { + for key, value := range a.Values { + if err := limits.ValidateIntegerBoundsIPLD(value); err != nil { + return fmt.Errorf("value for key %q: %w", key, err) + } + } + + return nil +} diff --git a/pkg/args/args_test.go b/pkg/args/args_test.go index 2a44d0f..938151f 100644 --- a/pkg/args/args_test.go +++ b/pkg/args/args_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ucan-wg/go-ucan/pkg/args" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" "github.com/ucan-wg/go-ucan/pkg/policy/literal" ) @@ -185,6 +186,71 @@ func TestInclude(t *testing.T) { }, maps.Collect(a1.Iter())) } +func TestArgsIntegerBounds(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + key string + val int64 + wantErr string + }{ + { + name: "valid int", + key: "valid", + val: 42, + }, + { + name: "max safe integer", + key: "max", + val: limits.MaxInt53, + }, + { + name: "min safe integer", + key: "min", + val: limits.MinInt53, + }, + { + name: "exceeds max safe integer", + key: "tooBig", + val: limits.MaxInt53 + 1, + wantErr: "exceeds safe integer bounds", + }, + { + name: "below min safe integer", + key: "tooSmall", + val: limits.MinInt53 - 1, + wantErr: "exceeds safe integer bounds", + }, + { + name: "duplicate key", + key: "duplicate", + val: 42, + wantErr: "duplicate key", + }, + } + + a := args.New() + require.NoError(t, a.Add("duplicate", 1)) // tests duplicate key + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := a.Add(tt.key, tt.val) + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + val, err := a.GetNode(tt.key) + require.NoError(t, err) + i, err := val.AsInt() + require.NoError(t, err) + require.Equal(t, tt.val, i) + } + }) + } +} + const ( argsSchema = "type Args { String : Any }" argsName = "Args" diff --git a/pkg/policy/ipld.go b/pkg/policy/ipld.go index 9d52d4d..752dd96 100644 --- a/pkg/policy/ipld.go +++ b/pkg/policy/ipld.go @@ -9,10 +9,15 @@ import ( "github.com/ipld/go-ipld-prime/must" "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" "github.com/ucan-wg/go-ucan/pkg/policy/selector" ) func FromIPLD(node datamodel.Node) (Policy, error) { + if err := limits.ValidateIntegerBoundsIPLD(node); err != nil { + return nil, fmt.Errorf("policy contains integer values outside safe bounds: %w", err) + } + return statementsFromIPLD("/", node) } diff --git a/pkg/policy/limits/int.go b/pkg/policy/limits/int.go new file mode 100644 index 0000000..91dd135 --- /dev/null +++ b/pkg/policy/limits/int.go @@ -0,0 +1,49 @@ +package limits + +import ( + "fmt" + + "github.com/ipld/go-ipld-prime" + "github.com/ipld/go-ipld-prime/must" +) + +const ( + // MaxInt53 represents the maximum safe integer in JavaScript (2^53 - 1) + MaxInt53 = 9007199254740991 + // MinInt53 represents the minimum safe integer in JavaScript (-2^53 + 1) + MinInt53 = -9007199254740991 +) + +func ValidateIntegerBoundsIPLD(node ipld.Node) error { + switch node.Kind() { + case ipld.Kind_Int: + val := must.Int(node) + if val > MaxInt53 || val < MinInt53 { + return fmt.Errorf("integer value %d exceeds safe bounds", val) + } + case ipld.Kind_List: + it := node.ListIterator() + for !it.Done() { + _, v, err := it.Next() + if err != nil { + return err + } + if err := ValidateIntegerBoundsIPLD(v); err != nil { + return err + } + } + case ipld.Kind_Map: + it := node.MapIterator() + for !it.Done() { + _, v, err := it.Next() + if err != nil { + return err + } + if err := ValidateIntegerBoundsIPLD(v); err != nil { + return err + } + } + } + + return nil +} diff --git a/pkg/policy/limits/int_test.go b/pkg/policy/limits/int_test.go new file mode 100644 index 0000000..de3f21e --- /dev/null +++ b/pkg/policy/limits/int_test.go @@ -0,0 +1,82 @@ +package limits + +import ( + "testing" + + "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/fluent/qp" + "github.com/ipld/go-ipld-prime/node/basicnode" + "github.com/stretchr/testify/require" +) + +func TestValidateIntegerBoundsIPLD(t *testing.T) { + buildMap := func() datamodel.Node { + nb := basicnode.Prototype.Any.NewBuilder() + qp.Map(1, func(ma datamodel.MapAssembler) { + qp.MapEntry(ma, "foo", qp.Int(MaxInt53+1)) + })(nb) + return nb.Build() + } + + buildList := func() datamodel.Node { + nb := basicnode.Prototype.Any.NewBuilder() + qp.List(1, func(la datamodel.ListAssembler) { + qp.ListEntry(la, qp.Int(MinInt53-1)) + })(nb) + return nb.Build() + } + + tests := []struct { + name string + input datamodel.Node + wantErr bool + }{ + { + name: "valid int", + input: basicnode.NewInt(42), + wantErr: false, + }, + { + name: "max safe int", + input: basicnode.NewInt(MaxInt53), + wantErr: false, + }, + { + name: "min safe int", + input: basicnode.NewInt(MinInt53), + wantErr: false, + }, + { + name: "above MaxInt53", + input: basicnode.NewInt(MaxInt53 + 1), + wantErr: true, + }, + { + name: "below MinInt53", + input: basicnode.NewInt(MinInt53 - 1), + wantErr: true, + }, + { + name: "nested map with invalid int", + input: buildMap(), + wantErr: true, + }, + { + name: "nested list with invalid int", + input: buildList(), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateIntegerBoundsIPLD(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "exceeds safe bounds") + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/policy/literal/literal.go b/pkg/policy/literal/literal.go index 33b0904..333ade3 100644 --- a/pkg/policy/literal/literal.go +++ b/pkg/policy/literal/literal.go @@ -12,6 +12,8 @@ import ( "github.com/ipld/go-ipld-prime/fluent/qp" cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/node/basicnode" + + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) var Bool = basicnode.NewBool @@ -58,8 +60,6 @@ func List[T any](l []T) (ipld.Node, error) { // Any creates an IPLD node from any value // If possible, use another dedicated function for your type for performance. func Any(v any) (res ipld.Node, err error) { - // TODO: handle uint overflow below - // some fast path switch val := v.(type) { case bool: @@ -67,7 +67,11 @@ func Any(v any) (res ipld.Node, err error) { case string: return basicnode.NewString(val), nil case int: - return basicnode.NewInt(int64(val)), nil + i := int64(val) + if i > limits.MaxInt53 || i < limits.MinInt53 { + return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", i) + } + return basicnode.NewInt(i), nil case int8: return basicnode.NewInt(int64(val)), nil case int16: @@ -75,6 +79,9 @@ func Any(v any) (res ipld.Node, err error) { case int32: return basicnode.NewInt(int64(val)), nil case int64: + if val > limits.MaxInt53 || val < limits.MinInt53 { + return nil, fmt.Errorf("integer value %d exceeds safe integer bounds", val) + } return basicnode.NewInt(val), nil case uint: return basicnode.NewInt(int64(val)), nil @@ -85,6 +92,9 @@ func Any(v any) (res ipld.Node, err error) { case uint32: return basicnode.NewInt(int64(val)), nil case uint64: + if val > uint64(limits.MaxInt53) { + return nil, fmt.Errorf("unsigned integer value %d exceeds safe integer bounds", val) + } return basicnode.NewInt(int64(val)), nil case float32: return basicnode.NewFloat(float64(val)), nil @@ -168,9 +178,17 @@ func anyAssemble(val any) qp.Assemble { case reflect.Bool: return qp.Bool(rv.Bool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return qp.Int(rv.Int()) + i := rv.Int() + if i > limits.MaxInt53 || i < limits.MinInt53 { + panic(fmt.Sprintf("integer %d exceeds safe bounds", i)) + } + return qp.Int(i) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return qp.Int(int64(rv.Uint())) + u := rv.Uint() + if u > limits.MaxInt53 { + panic(fmt.Sprintf("unsigned integer %d exceeds safe bounds", u)) + } + return qp.Int(int64(u)) case reflect.Float32, reflect.Float64: return qp.Float(rv.Float()) case reflect.String: diff --git a/pkg/policy/literal/literal_test.go b/pkg/policy/literal/literal_test.go index 45d7c6c..9f7e1f2 100644 --- a/pkg/policy/literal/literal_test.go +++ b/pkg/policy/literal/literal_test.go @@ -8,6 +8,8 @@ import ( cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/ipld/go-ipld-prime/printer" "github.com/stretchr/testify/require" + + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) func TestList(t *testing.T) { @@ -214,7 +216,7 @@ func TestAny(t *testing.T) { require.NoError(t, err) require.True(t, asLink.(cidlink.Link).Equals(cid.MustParse("bafzbeigai3eoy2ccc7ybwjfz5r3rdxqrinwi4rwytly24tdbh6yk7zslrm"))) - v, err = Any(data["func"]) + _, err = Any(data["func"]) require.Error(t, err) } @@ -254,6 +256,56 @@ func BenchmarkAny(b *testing.B) { }) } +func TestAnyAssembleIntegerOverflow(t *testing.T) { + tests := []struct { + name string + input interface{} + shouldErr bool + }{ + { + name: "valid int", + input: 42, + shouldErr: false, + }, + { + name: "max safe int", + input: limits.MaxInt53, + shouldErr: false, + }, + { + name: "min safe int", + input: limits.MinInt53, + shouldErr: false, + }, + { + name: "overflow int", + input: int64(limits.MaxInt53 + 1), + shouldErr: true, + }, + { + name: "underflow int", + input: int64(limits.MinInt53 - 1), + shouldErr: true, + }, + { + name: "overflow uint", + input: uint64(limits.MaxInt53 + 1), + shouldErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := Any(tt.input) + if tt.shouldErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + func must[T any](t T, err error) T { if err != nil { panic(err) diff --git a/pkg/policy/match.go b/pkg/policy/match.go index 59316ed..648b877 100644 --- a/pkg/policy/match.go +++ b/pkg/policy/match.go @@ -3,6 +3,7 @@ package policy import ( "cmp" "fmt" + "math" "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/datamodel" @@ -249,10 +250,22 @@ func matchStatement(cur Statement, node ipld.Node) (_ matchResult, leafMost Stat panic(fmt.Errorf("unimplemented statement kind: %s", cur.Kind())) } +// isOrdered compares two IPLD nodes and returns true if they satisfy the given ordering function. +// It supports comparison of integers and floats, returning false for: +// - Nodes of different or unsupported kinds +// - Integer values outside JavaScript's safe integer bounds (±2^53-1) +// - Non-finite floating point values (NaN or ±Inf) +// +// The satisfies parameter is a function that interprets the comparison result: +// - For ">" it returns true when order is 1 +// - For ">=" it returns true when order is 0 or 1 +// - For "<" it returns true when order is -1 +// - For "<=" it returns true when order is -1 or 0 func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) bool) bool { if expected.Kind() == ipld.Kind_Int && actual.Kind() == ipld.Kind_Int { a := must.Int(actual) b := must.Int(expected) + return satisfies(cmp.Compare(a, b)) } @@ -265,6 +278,11 @@ func isOrdered(expected ipld.Node, actual ipld.Node, satisfies func(order int) b if err != nil { panic(fmt.Errorf("extracting selector float: %w", err)) } + + if math.IsInf(a, 0) || math.IsNaN(a) || math.IsInf(b, 0) || math.IsNaN(b) { + return false + } + return satisfies(cmp.Compare(a, b)) } diff --git a/pkg/policy/selector/parsing.go b/pkg/policy/selector/parsing.go index 05ab092..5ccc771 100644 --- a/pkg/policy/selector/parsing.go +++ b/pkg/policy/selector/parsing.go @@ -6,6 +6,8 @@ import ( "regexp" "strconv" "strings" + + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) var ( @@ -56,6 +58,9 @@ func Parse(str string) (Selector, error) { if err != nil { return nil, newParseError("invalid index", str, col, tok) } + if idx > limits.MaxInt53 || idx < limits.MinInt53 { + return nil, newParseError(fmt.Sprintf("index %d exceeds safe integer bounds", idx), str, col, tok) + } sel = append(sel, segment{str: tok, optional: opt, index: idx}) // explicit field, ["abcd"] @@ -77,6 +82,9 @@ func Parse(str string) (Selector, error) { if err != nil { return nil, newParseError("invalid slice index", str, col, tok) } + if i > limits.MaxInt53 || i < limits.MinInt53 { + return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok) + } rng[0] = i } if splt[1] == "" { @@ -86,6 +94,9 @@ func Parse(str string) (Selector, error) { if err != nil { return nil, newParseError("invalid slice index", str, col, tok) } + if i > limits.MaxInt53 || i < limits.MinInt53 { + return nil, newParseError(fmt.Sprintf("slice index %d exceeds safe integer bounds", i), str, col, tok) + } rng[1] = i } sel = append(sel, segment{str: tok, optional: opt, slice: rng[:]}) diff --git a/pkg/policy/selector/parsing_test.go b/pkg/policy/selector/parsing_test.go index b84ad52..2810227 100644 --- a/pkg/policy/selector/parsing_test.go +++ b/pkg/policy/selector/parsing_test.go @@ -1,10 +1,12 @@ package selector import ( + "fmt" "math" "testing" "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) func TestParse(t *testing.T) { @@ -572,4 +574,23 @@ func TestParse(t *testing.T) { _, err := Parse(".[foo]") require.Error(t, err) }) + + t.Run("integer overflow", func(t *testing.T) { + sel, err := Parse(fmt.Sprintf(".[%d]", limits.MaxInt53+1)) + require.Error(t, err) + require.Nil(t, sel) + + sel, err = Parse(fmt.Sprintf(".[%d]", limits.MinInt53-1)) + require.Error(t, err) + require.Nil(t, sel) + + // Test slice overflow + sel, err = Parse(fmt.Sprintf(".[%d:42]", limits.MaxInt53+1)) + require.Error(t, err) + require.Nil(t, sel) + + sel, err = Parse(fmt.Sprintf(".[1:%d]", limits.MaxInt53+1)) + require.Error(t, err) + require.Nil(t, sel) + }) } diff --git a/pkg/policy/selector/selector.go b/pkg/policy/selector/selector.go index 149078d..249cd44 100644 --- a/pkg/policy/selector/selector.go +++ b/pkg/policy/selector/selector.go @@ -266,19 +266,32 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) { case slice[0] == math.MinInt: start = 0 case slice[0] < 0: - start = length + slice[0] + // Check for potential overflow before adding + if -slice[0] > length { + start = 0 + } else { + start = length + slice[0] + } } + switch { case slice[1] == math.MaxInt: end = length case slice[1] < 0: - end = length + slice[1] + // Check for potential overflow before adding + if -slice[1] > length { + end = 0 + } else { + end = length + slice[1] + } } // backward iteration is not allowed, shortcut to an empty result if start >= end { start, end = 0, 0 + return } + // clamp out of bound if start < 0 { start = 0 @@ -286,11 +299,14 @@ func resolveSliceIndices(slice []int64, length int64) (start int64, end int64) { if start > length { start = length } + if end < 0 { + end = 0 + } if end > length { end = length } - return start, end + return } func kindString(n datamodel.Node) string { diff --git a/pkg/policy/selector/selector_test.go b/pkg/policy/selector/selector_test.go index 184b7b3..fdd18ec 100644 --- a/pkg/policy/selector/selector_test.go +++ b/pkg/policy/selector/selector_test.go @@ -2,6 +2,7 @@ package selector import ( "errors" + "math" "strings" "testing" @@ -356,3 +357,57 @@ func FuzzParseAndSelect(f *testing.F) { } }) } + +func TestResolveSliceIndices(t *testing.T) { + tests := []struct { + name string + slice []int64 + length int64 + wantStart int64 + wantEnd int64 + }{ + { + name: "normal case", + slice: []int64{1, 3}, + length: 5, + wantStart: 1, + wantEnd: 3, + }, + { + name: "negative indices", + slice: []int64{-2, -1}, + length: 5, + wantStart: 3, + wantEnd: 4, + }, + { + name: "overflow protection negative start", + slice: []int64{math.MinInt64, 3}, + length: 5, + wantStart: 0, + wantEnd: 3, + }, + { + name: "overflow protection negative end", + slice: []int64{0, math.MinInt64}, + length: 5, + wantStart: 0, + wantEnd: 0, + }, + { + name: "max bounds", + slice: []int64{0, math.MaxInt64}, + length: 5, + wantStart: 0, + wantEnd: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start, end := resolveSliceIndices(tt.slice, tt.length) + require.Equal(t, tt.wantStart, start) + require.Equal(t, tt.wantEnd, end) + }) + } +} diff --git a/token/delegation/delegation.go b/token/delegation/delegation.go index 959a40a..110b8e1 100644 --- a/token/delegation/delegation.go +++ b/token/delegation/delegation.go @@ -215,8 +215,15 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) { tkn.meta = m.Meta - tkn.notBefore = parse.OptionalTimestamp(m.Nbf) - tkn.expiration = parse.OptionalTimestamp(m.Exp) + tkn.notBefore, err = parse.OptionalTimestamp(m.Nbf) + if err != nil { + return nil, fmt.Errorf("parse notBefore: %w", err) + } + + tkn.expiration, err = parse.OptionalTimestamp(m.Exp) + if err != nil { + return nil, fmt.Errorf("parse expiration: %w", err) + } if err := tkn.validate(); err != nil { return nil, err diff --git a/token/internal/parse/parse.go b/token/internal/parse/parse.go index 147b308..27af240 100644 --- a/token/internal/parse/parse.go +++ b/token/internal/parse/parse.go @@ -1,9 +1,11 @@ package parse import ( + "fmt" "time" "github.com/ucan-wg/go-ucan/did" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" ) func OptionalDID(s *string) (did.DID, error) { @@ -13,10 +15,15 @@ func OptionalDID(s *string) (did.DID, error) { return did.Parse(*s) } -func OptionalTimestamp(sec *int64) *time.Time { +func OptionalTimestamp(sec *int64) (*time.Time, error) { if sec == nil { - return nil + return nil, nil } + + if *sec > limits.MaxInt53 || *sec < limits.MinInt53 { + return nil, fmt.Errorf("timestamp value %d exceeds safe integer bounds", *sec) + } + t := time.Unix(*sec, 0) - return &t + return &t, nil } diff --git a/token/internal/parse/parse_test.go b/token/internal/parse/parse_test.go new file mode 100644 index 0000000..9db6474 --- /dev/null +++ b/token/internal/parse/parse_test.go @@ -0,0 +1,64 @@ +package parse + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/ucan-wg/go-ucan/pkg/policy/limits" +) + +func TestOptionalTimestamp(t *testing.T) { + tests := []struct { + name string + input *int64 + wantErr bool + }{ + { + name: "nil timestamp", + input: nil, + wantErr: false, + }, + { + name: "valid timestamp", + input: int64Ptr(1625097600), + wantErr: false, + }, + { + name: "max safe integer", + input: int64Ptr(limits.MaxInt53), + wantErr: false, + }, + { + name: "exceeds max safe integer", + input: int64Ptr(limits.MaxInt53 + 1), + wantErr: true, + }, + { + name: "below min safe integer", + input: int64Ptr(limits.MinInt53 - 1), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := OptionalTimestamp(tt.input) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "exceeds safe integer bounds") + require.Nil(t, result) + } else { + require.NoError(t, err) + if tt.input == nil { + require.Nil(t, result) + } else { + require.NotNil(t, result) + } + } + }) + } +} + +func int64Ptr(i int64) *int64 { + return &i +} diff --git a/token/invocation/invocation.go b/token/invocation/invocation.go index ee85a94..4ab7b8b 100644 --- a/token/invocation/invocation.go +++ b/token/invocation/invocation.go @@ -272,11 +272,22 @@ func tokenFromModel(m tokenPayloadModel) (*Token, error) { tkn.nonce = m.Nonce tkn.arguments = m.Args + if err := tkn.arguments.Validate(); err != nil { + return nil, fmt.Errorf("invalid arguments: %w", err) + } + tkn.proof = m.Prf tkn.meta = m.Meta - tkn.expiration = parse.OptionalTimestamp(m.Exp) - tkn.invokedAt = parse.OptionalTimestamp(m.Iat) + tkn.expiration, err = parse.OptionalTimestamp(m.Exp) + if err != nil { + return nil, fmt.Errorf("parse expiration: %w", err) + } + + tkn.invokedAt, err = parse.OptionalTimestamp(m.Iat) + if err != nil { + return nil, fmt.Errorf("parse invokedAt: %w", err) + } tkn.cause = m.Cause