Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add UInt debug_assert(value >= 0) for Int implicit conversion #3753

Open
wants to merge 11 commits into
base: nightly
Choose a base branch
from
29 changes: 13 additions & 16 deletions stdlib/src/builtin/int.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -261,18 +261,6 @@ fn int(value: String, base: Int = 10) raises -> Int:
return atol(value, base)


fn int(value: UInt) -> Int:
"""Get the Int representation of the value.

Args:
value: The object to get the integral representation of.

Returns:
The integral representation of the value.
"""
return value.value


# ===----------------------------------------------------------------------=== #
# Int
# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -1020,13 +1008,22 @@ struct Int(

@always_inline("nodebug")
fn __int__(self) -> Int:
"""Gets the integral value (this is an identity function for Int).
"""Gets the integral value.

Returns:
The value as an integer.
"""
return self

@always_inline("nodebug")
fn __uint__(self) -> UInt:
"""Gets the unsigned integral value.

Returns:
The value as an unsigned integer.
"""
return self.value

@always_inline("nodebug")
fn __abs__(self) -> Self:
"""Return the absolute value of the Int value.
Expand Down Expand Up @@ -1109,9 +1106,9 @@ struct Int(
"""Hash the int using builtin hash.

Returns:
A 64-bit hash value. This value is _not_ suitable for cryptographic
uses. Its intended usage is for data structures. See the `hash`
builtin documentation for more details.
A 64-bit hash value. This value is **_not_** suitable for
cryptographic uses. Its intended usage is for data structures. See
the `hash` builtin documentation for more details.
"""
# TODO(MOCO-636): switch to DType.index
return _hash_simd(Scalar[DType.int64](self))
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/builtin/range.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ struct _UIntStridedRange(UIntSized, _UIntStridedIterable):
step != 0, "range() arg 3 (the step size) must not be zero"
)
debug_assert(
step != UInt(Int(-1)),
step != uint(-1),
(
"range() arg 3 (the step size) cannot be -1. Reverse range is"
" not supported yet for UInt ranges."
Expand Down
13 changes: 13 additions & 0 deletions stdlib/src/builtin/simd.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,19 @@ struct SIMD[type: DType, size: Int](
_type = __mlir_type.`!pop.scalar<index>`
](rebind[Scalar[type]](self).value)

@always_inline("nodebug")
fn __uint__(self) -> UInt:
"""Casts to the value to a UInt. If there is a fractional component,
then the fractional part is truncated.

Constraints:
The size of the SIMD vector must be 1.

Returns:
The value as an integer.
"""
return uint(int(self))

@always_inline("nodebug")
fn __mlir_index__(self) -> __mlir_type.index:
"""Convert to index.
Expand Down
112 changes: 110 additions & 2 deletions stdlib/src/builtin/uint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,86 @@ from hashlib.hash import _hash_simd
from hashlib._hasher import _HashableWithHasher, _Hasher


# ===----------------------------------------------------------------------=== #
# UIntable
# ===----------------------------------------------------------------------=== #


trait UIntable:
"""The `UIntable` trait describes a type that can be converted to a `UInt`.
"""

fn __uint__(self) -> UInt:
"""Get the unsigned integral representation of the value.

Returns:
The unsigned integral representation of the value.
"""
...


trait UIntableRaising:
"""The `UIntableRaising` trait describes a type can be converted to a UInt,
but the conversion might raise an error.
"""

fn __uint__(self) raises -> UInt:
"""Get the unsigned integral representation of the value.

Returns:
The unsigned integral representation of the type.

Raises:
If the type does not have an unsigned integral representation.
"""
...


# ===----------------------------------------------------------------------=== #
# uint
# ===----------------------------------------------------------------------=== #


@always_inline
fn uint[T: UIntable](value: T) -> UInt:
"""Get the `UInt` representation of the value.

Parameters:
T: The UIntable type.

Args:
value: The object to get the integral representation of.

Returns:
The unsigned integral representation of the value.
"""
return value.__uint__()


@always_inline
fn uint[T: UIntableRaising](value: T) raises -> UInt:
"""Get the `UInt` representation of the value.

Parameters:
T: The UIntable type.

Args:
value: The object to get the integral representation of.

Returns:
The unsigned integral representation of the value.

Raises:
If the type does not have an integral representation.
"""
return value.__uint__()


# ===----------------------------------------------------------------------=== #
# UInt
# ===----------------------------------------------------------------------=== #


@lldb_formatter_wrapping_type
@value
@register_passable("trivial")
Expand Down Expand Up @@ -82,20 +162,30 @@ struct UInt(IntLike, _HashableWithHasher):

@always_inline("nodebug")
fn __init__(out self, value: Int):
"""Construct UInt from the given index value.
"""Construct `UInt` from the given `Int` value.

Args:
value: The init value.
"""
debug_assert(
value >= 0,
"Constructing `UInt` from negative `Int` is discouraged, use the",
" `uint` builtin function (e.g. `uint(-1)`) if you are sure.",
)
self.value = value.value

@always_inline("nodebug")
fn __init__(out self, value: IntLiteral):
"""Construct UInt from the given IntLiteral value.
"""Construct `UInt` from the given `IntLiteral` value.

Args:
value: The init value.
"""
debug_assert(
value >= 0,
"Constructing `UInt` from negative `Int` is discouraged, use the",
" `uint` builtin function (e.g. `uint(-1)`) if you are sure.",
)
self = value.__uint__()

@always_inline("nodebug")
Expand All @@ -107,6 +197,24 @@ struct UInt(IntLike, _HashableWithHasher):
"""
return self.value

@always_inline("nodebug")
fn __int__(self) -> Int:
"""Gets the integral value.

Returns:
The value as an integer.
"""
return self.value

@always_inline("nodebug")
fn __uint__(self) -> UInt:
"""Gets the unsigned integral value.

Returns:
The value as an unsigned integer.
"""
return self

@no_inline
fn __str__(self) -> String:
"""Convert this UInt to a string.
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/hashlib/hash.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ fn _hash_simd[type: DType, size: Int](data: SIMD[type, size]) -> UInt:
bitcast[int_type, 1](hash_data[i]).cast[DType.uint64](),
)

return int(final_data)
return uint(final_data)


fn hash(bytes: UnsafePointer[UInt8], n: Int) -> UInt:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/prelude/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ from builtin.type_aliases import (
OriginSet,
Origin,
)
from builtin.uint import UInt
from builtin.uint import UInt, uint
from builtin.value import (
Movable,
Copyable,
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/utils/static_tuple.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ struct StaticTuple[element_type: AnyTrivialRegType, size: Int](Sized):
The value at the specified position.
"""
debug_assert(
int(idx.__mlir_index__()) < size, "index must be within bounds"
Int(idx.__mlir_index__()) < size, "index must be within bounds"
)
# Copy the array so we can get its address, because we can't take the
# address of 'self' in a non-mutating method.
Expand All @@ -230,7 +230,7 @@ struct StaticTuple[element_type: AnyTrivialRegType, size: Int](Sized):
val: The value to store.
"""
debug_assert(
int(idx.__mlir_index__()) < size, "index must be within bounds"
Int(idx.__mlir_index__()) < size, "index must be within bounds"
)
var tmp = self
var ptr = __mlir_op.`pop.array.gep`(
Expand Down
50 changes: 25 additions & 25 deletions stdlib/test/builtin/test_uint.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,18 @@ def test_properties():

def test_add():
assert_equal(UInt.__add__(UInt(3), UInt(3)), UInt(6))
assert_equal(UInt.__add__(UInt(Int(-2)), UInt(3)), UInt(1))
assert_equal(UInt.__add__(UInt(2), UInt(Int(-3))), UInt(Int(-1)))
assert_equal(UInt.__add__(UInt(5), UInt(Int(-5))), UInt(0))
assert_equal(UInt.__add__(UInt(Int(-5)), UInt(Int(-4))), UInt(Int(-9)))
assert_equal(UInt.__add__(uint(-2), UInt(3)), UInt(1))
assert_equal(UInt.__add__(UInt(2), uint(-3)), uint(-1))
assert_equal(UInt.__add__(UInt(5), uint(-5)), UInt(0))
assert_equal(UInt.__add__(uint(-5), uint(-4)), uint(-9))


def test_sub():
assert_equal(UInt.__sub__(UInt(3), UInt(3)), UInt(0))
assert_equal(UInt.__sub__(UInt(Int(-2)), UInt(3)), UInt(Int(-5)))
assert_equal(UInt.__sub__(UInt(2), UInt(Int(-3))), UInt(5))
assert_equal(UInt.__sub__(uint(-2), UInt(3)), uint(-5))
assert_equal(UInt.__sub__(UInt(2), uint(-3)), UInt(5))
assert_equal(UInt.__sub__(UInt(5), UInt(4)), UInt(1))
assert_equal(UInt.__sub__(UInt(4), UInt(5)), UInt(Int(-1)))
assert_equal(UInt.__sub__(UInt(4), UInt(5)), uint(-1))


def test_div():
Expand All @@ -128,23 +128,23 @@ def test_pow():
def test_ceil():
assert_equal(UInt.__ceil__(UInt(5)), UInt(5))
assert_equal(UInt.__ceil__(UInt(0)), UInt(0))
assert_equal(UInt.__ceil__(UInt(Int(-5))), UInt(Int(-5)))
assert_equal(UInt.__ceil__(uint(-5)), uint(-5))


def test_floor():
assert_equal(UInt.__floor__(UInt(5)), UInt(5))
assert_equal(UInt.__floor__(UInt(0)), UInt(0))
assert_equal(UInt.__floor__(UInt(Int(-5))), UInt(Int(-5)))
assert_equal(UInt.__floor__(uint(-5)), uint(-5))


def test_round():
assert_equal(UInt.__round__(UInt(5)), UInt(5))
assert_equal(UInt.__round__(UInt(0)), UInt(0))
assert_equal(UInt.__round__(UInt(Int(-5))), UInt(Int(-5)))
assert_equal(UInt.__round__(uint(-5)), uint(-5))
assert_equal(UInt.__round__(UInt(5), UInt(1)), UInt(5))
assert_equal(UInt.__round__(UInt(0), UInt(1)), UInt(0))
assert_equal(UInt.__round__(UInt(Int(-5)), UInt(1)), UInt(Int(-5)))
assert_equal(UInt.__round__(UInt(100), UInt(Int(-2))), UInt(100))
assert_equal(UInt.__round__(uint(-5), UInt(1)), uint(-5))
assert_equal(UInt.__round__(UInt(100), uint(-2)), UInt(100))


def test_trunc():
Expand All @@ -156,20 +156,20 @@ def test_floordiv():
assert_equal(UInt(1), UInt.__floordiv__(UInt(2), UInt(2)))
assert_equal(UInt(0), UInt.__floordiv__(UInt(2), UInt(3)))
assert_equal(UInt(2), UInt.__floordiv__(UInt(100), UInt(50)))
assert_equal(UInt(0), UInt.__floordiv__(UInt(2), UInt(Int(-2))))
assert_equal(UInt(0), UInt.__floordiv__(UInt(99), UInt(Int(-2))))
assert_equal(UInt(0), UInt.__floordiv__(UInt(2), uint(-2)))
assert_equal(UInt(0), UInt.__floordiv__(UInt(99), uint(-2)))


def test_mod():
assert_equal(UInt(0), UInt.__mod__(UInt(99), UInt(1)))
assert_equal(UInt(0), UInt.__mod__(UInt(99), UInt(3)))
assert_equal(UInt(99), UInt.__mod__(UInt(99), UInt(Int(-2))))
assert_equal(UInt(99), UInt.__mod__(UInt(99), uint(-2)))
assert_equal(UInt(3), UInt.__mod__(UInt(99), UInt(8)))
assert_equal(UInt(99), UInt.__mod__(UInt(99), UInt(Int(-8))))
assert_equal(UInt(2), UInt.__mod__(UInt(2), UInt(Int(-1))))
assert_equal(UInt(2), UInt.__mod__(UInt(2), UInt(Int(-2))))
assert_equal(UInt(3), UInt.__mod__(UInt(3), UInt(Int(-2))))
assert_equal(UInt(1), UInt.__mod__(UInt(Int(-3)), UInt(2)))
assert_equal(UInt(99), UInt.__mod__(UInt(99), uint(-8)))
assert_equal(UInt(2), UInt.__mod__(UInt(2), uint(-1)))
assert_equal(UInt(2), UInt.__mod__(UInt(2), uint(-2)))
assert_equal(UInt(3), UInt.__mod__(UInt(3), uint(-2)))
assert_equal(UInt(1), UInt.__mod__(uint(-3), UInt(2)))


def test_divmod():
Expand All @@ -189,25 +189,25 @@ def test_divmod():


def test_abs():
assert_equal(UInt(Int(-5)).__abs__(), UInt(18446744073709551611))
assert_equal(uint(-5).__abs__(), UInt(18446744073709551611))
assert_equal(UInt(2).__abs__(), UInt(2))
assert_equal(UInt(0).__abs__(), UInt(0))


def test_string_conversion():
assert_equal(UInt(3).__str__(), "3")
assert_equal(UInt(Int(-3)).__str__(), "18446744073709551613")
assert_equal(uint(-3).__str__(), "18446744073709551613")
assert_equal(UInt(0).__str__(), "0")
assert_equal(UInt(100).__str__(), "100")
assert_equal(UInt(Int(-100)).__str__(), "18446744073709551516")
assert_equal(uint(-100).__str__(), "18446744073709551516")


def test_int_representation():
assert_equal(UInt(3).__repr__(), "UInt(3)")
assert_equal(UInt(Int(-3)).__repr__(), "UInt(18446744073709551613)")
assert_equal(uint(-3).__repr__(), "UInt(18446744073709551613)")
assert_equal(UInt(0).__repr__(), "UInt(0)")
assert_equal(UInt(100).__repr__(), "UInt(100)")
assert_equal(UInt(Int(-100)).__repr__(), "UInt(18446744073709551516)")
assert_equal(uint(-100).__repr__(), "UInt(18446744073709551516)")


def test_indexer():
Expand Down
2 changes: 1 addition & 1 deletion stdlib/test/collections/test_dict.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ struct DummyKey(KeyElement):
self = other

fn __hash__(self) -> UInt:
return self.value
return uint(self.value)

fn __eq__(self, other: DummyKey) -> Bool:
return self.value == other.value
Expand Down