Skip to content

Commit

Permalink
Number proxies is not a number (#286)
Browse files Browse the repository at this point in the history
Fixes #272.

Changes in this PR:

Removing the inheritance of Number kinds in NumberProxy derivatives. The motivation is to avoid mistakenly identifying a NumberProxy as a Number. Specifically, we are removing complex / int / float from the inheritance of ComplexProxy / IntegerProxy / FloatProxy.
Changing existing checks from isinstance(x, Number) to isinstance(x, (Number, NumberProxy)).
Handling proper type promotion in NumberProxy's binary/unary operations.

Co-authored-by: Thomas Viehmann <[email protected]>
  • Loading branch information
jjsjann123 and t-vi authored May 16, 2024
1 parent 2fa97d6 commit 6deb2cc
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 185 deletions.
50 changes: 27 additions & 23 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import thunder.core.dtypes as dtypes
from thunder.core import utils
import thunder.core.prims as prims
from thunder.core.proxies import TensorProxy, pyval, pytype, proxy, AnyProxy, Proxy
from thunder.core.proxies import IntegerProxy, NumberProxy, TensorProxy, pyval, pytype, proxy, AnyProxy, Proxy
import thunder.core.devices as devices

# This file defines the operations in thunder.jit's "core" language.
Expand All @@ -29,6 +29,7 @@

__all__ = []

NumberLike = Number | NumberProxy
TensorLike = TensorProxy
DeviceLike = Union[str, devices.Device]

Expand Down Expand Up @@ -85,7 +86,7 @@ def check_instance(x: Any, types: tuple[type], /) -> None:

# Checks a number's value
@clangop()
def check_number_type_and_value(n: Number, value: Number, /) -> None:
def check_number_type_and_value(n: NumberLike, value: Number, /) -> None:
return prims.check_number_type_and_value(n, value)


Expand Down Expand Up @@ -140,7 +141,7 @@ def maybe_convert_to_dtype(a, dtype, *, enforce_safe_casting=False):
# Translates numbertypes to dtypes
if dtypes.is_numbertype(dtype):
dtype = dtypes.numbertype_to_dtype(dtype)
elif isinstance(a, Number):
elif isinstance(a, (Number, NumberProxy)):
# NOTE This allows conversions like (5, float32) -> 5., which is a little odd
dtype = utils.dtype_to_numbertype(dtype)
else:
Expand Down Expand Up @@ -176,7 +177,9 @@ def device_put(a, device):

# TODO Add type annotations
@clangop()
def arange(*, start: Number, step: Number, stop: Number, device: DeviceLike, dtype: dtypes.dtype | None = None):
def arange(
*, start: NumberLike, step: NumberLike, stop: NumberLike, device: DeviceLike, dtype: dtypes.dtype | None = None
):
# Validates inputs
# Checks that start, step, and stop are finite
# TODO Semantically an infinite step seems fine?
Expand All @@ -202,7 +205,8 @@ def arange(*, start: Number, step: Number, stop: Number, device: DeviceLike, dty
# (Optionally) infers dtype
# TODO Replace with default datatypes for integer and float
if dtype is None:
if all(tuple(isinstance(x, int) for x in (start, step, stop))):
# TODO: maybe something like a isIntegerType?
if all(tuple(isinstance(x, (int, IntegerProxy)) for x in (start, step, stop))):
dtype = dtypes.int64
else:
dtype = dtypes.float32
Expand Down Expand Up @@ -248,7 +252,7 @@ def convolution(

@clangop()
def full(
shape: Sequence[int], fill_value: Number, *, device: DeviceLike, dtype: None | dtypes.dtype = None
shape: Sequence[int], fill_value: NumberLike, *, device: DeviceLike, dtype: None | dtypes.dtype = None
) -> TensorLike:
# Infers dtype from the fill_value when not explicitly provided
if dtype is None:
Expand All @@ -261,12 +265,12 @@ def full(
@clangop()
def full_like(
a: TensorLike | Number,
fill_value: Number,
fill_value: NumberLike,
*,
device: DeviceLike | None = None,
dtype: dtypes.dtype | None = None,
) -> TensorLike:
if isinstance(a, Number):
if isinstance(a, (Number, NumberProxy)):
dtype = pytype(fill_value) if dtype is None else dtypes.dtype_to_numbertype(dtype)
utils.check(
device is None or devices.to_device(device).devicetype is devices.DeviceType.CPU,
Expand All @@ -290,8 +294,8 @@ def empty(shape: Sequence[int], *, device: DeviceLike, dtype: dtypes.dtype) -> T
@clangop()
def uniform(
shape: Sequence[int],
minval: Number = 0.0,
maxval: Number = 1.0,
minval: NumberLike = 0.0,
maxval: NumberLike = 1.0,
*,
device: DeviceLike,
dtype: dtypes.dtype,
Expand All @@ -305,8 +309,8 @@ def uniform(
@clangop()
def uniform_like(
a: TensorProxy,
minval: Number = 0.0,
maxval: Number = 1.0,
minval: NumberLike = 0.0,
maxval: NumberLike = 1.0,
*,
device: str | devices.Device | None = None,
dtype: dtypes.dtype | None = None,
Expand All @@ -320,8 +324,8 @@ def uniform_like(
@clangop()
def uniform_philox(
shape: Sequence[int],
minval: Number = 0.0,
maxval: Number = 1.0,
minval: NumberLike = 0.0,
maxval: NumberLike = 1.0,
*,
device: DeviceLike,
dtype: dtypes.dtype,
Expand Down Expand Up @@ -487,7 +491,7 @@ def _get_indexing_signature(key: Any) -> IndexingSignature:
return sig

# Numbers and slices are examples of basic indexing.
if isinstance(key, (Number, slice)):
if isinstance(key, (Number, NumberProxy, slice)):
sig.basic.append((None, None))
return sig

Expand Down Expand Up @@ -546,7 +550,7 @@ def __next__(self):
elif k is None:
sig.unsqueeze.append(i)
else:
if isinstance(k, (Number, slice)):
if isinstance(k, (Number, slice, NumberProxy)):
sig.basic.append((a_dim, i))
elif isinstance(k, (TensorLike, Sequence)):
sig.advanced.append((a_dim, i))
Expand All @@ -573,14 +577,14 @@ def _basic_indexing(a: TensorLike, /, key) -> TensorLike:
specified_slices = 0
ellipsis_idx = None

if key is None or isinstance(key, (Number, slice, EllipsisType)):
if key is None or isinstance(key, (Number, NumberProxy, slice, EllipsisType)):
key = (key,)

for idx, x in enumerate(key):
if x is Ellipsis:
utils.check(ellipsis_idx is None, lambda: f"Found two (or more) ellipses in key={key}")
ellipsis_idx = idx
elif isinstance(x, (Number, slice)):
elif isinstance(x, (NumberProxy, Number, slice)):
specified_slices += 1
elif x is None:
if ellipsis_idx is None:
Expand Down Expand Up @@ -660,7 +664,7 @@ def _convert_none(x):
start_indices.append(start)
end_indices.append(stop)
strides.append(step)
elif isinstance(x, Number):
elif isinstance(x, (Number, NumberProxy)):
# NOTE Numbers must be valid indices after canonicalization, unlike start and stop
x = utils.canonicalize_dim(l, x)
start_indices.append(x)
Expand Down Expand Up @@ -829,7 +833,7 @@ def getitem(a: TensorLike, /, key) -> TensorLike:
if key_idx is not None:
key_idx = key_idx if key_idx >= 0 else len(key) + key_idx
index = key[key_idx]
if isinstance(index, Sequence) and len(index) == 1 and isinstance(index[0], Number):
if isinstance(index, Sequence) and len(index) == 1 and isinstance(index[0], (Number, NumberProxy)):
start = index[0]
# Handle -1 to avoid empty slices
if start == -1:
Expand Down Expand Up @@ -1098,7 +1102,7 @@ def index_put(
# NOTE: the dimensions do not have to be specified in any order
@clangop()
def unsqueeze(a, /, dims: int | Sequence[int]) -> TensorProxy:
if isinstance(dims, Number):
if isinstance(dims, (Number, NumberProxy)):
dims = (dims,)

# Short-circuits if dims is empty
Expand Down Expand Up @@ -1338,7 +1342,7 @@ def ceil(a: TensorLike | Number) -> TensorLike | Number:
return _elementwise_unary_wrapper(
a,
prim=prims.ceil,
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NUMBER_TO_INT,
)


Expand Down Expand Up @@ -1421,7 +1425,7 @@ def floor(a: TensorLike | Number) -> TensorLike | Number:
return _elementwise_unary_wrapper(
a,
prim=prims.floor,
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.NUMBER_TO_INT,
)


Expand Down
12 changes: 6 additions & 6 deletions thunder/clang/langctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from thunder.core.langctxs import LanguageContext, register_langctx, Languages, resolve_language
from thunder.core.pytree import tree_flatten
from thunder.core.proxies import TensorProxy
from thunder.core.proxies import TensorProxy, NumberProxy

#
# Creates and registers the torch language context
Expand All @@ -30,19 +30,19 @@ def get_method(self, id: str, *args, **kwargs) -> Callable:
# not exist.
inps, _ = tree_flatten((args, kwargs))

has_tensor_input: bool = False
has_proxy_input: bool = False
for x in inps:
if isinstance(x, TensorProxy):
has_tensor_input = True
if isinstance(x, TensorProxy) or isinstance(x, NumberProxy):
has_proxy_input = True
break

if has_tensor_input:
if has_proxy_input:
method: None | Callable = _method_name_to_fn_map.get(id, None)
if method is None:
raise AttributeError(f"The {self.name} language context has no method {id}")
return method

# has_tensor_input is False
# has_proxy_input is False
# Defers to the primitive language context when there are no tensor inputs
# (the primitive language context handles operations on numbers)
primsctx: LanguageContext = resolve_language(Languages.PRIMS)
Expand Down
3 changes: 2 additions & 1 deletion thunder/core/baseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def check_types(xs: Sequence[Any], types: type | Sequence[type]):
def check_valid_length(length: int):
"""Validates that an object represents a valid dimension length."""

check_type(length, int)
# maybe we should skip the check for IntegerProxy in general
check_type(length, (int, NumberProxyInterface))
check(length >= 0, lambda: f"Found invalid length {length}!")


Expand Down
Loading

0 comments on commit 6deb2cc

Please sign in to comment.