diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 65f0199093..74b1456149 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -975,6 +975,11 @@ def getitem(a: TensorLike, /, key) -> TensorLike: lambda: f"{key=} tries to index more dimensions than {a.ndim=}", ) + # FIXME: This is a quick WAR to avoid accessing shape attribute of a without + # definition. This needs to be done properly somewhere else. See issue + # github.com/Lightning-AI/lightning-thunder/issues/1253 + old_shape = prims.shape(a) + # We do not support mixing basic and advanced indexing together yet, # but a very special case when there is a single advanced index which # is a sequence of length 1. diff --git a/thunder/common.py b/thunder/common.py index 168c415fbb..92ebeac8e5 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -291,6 +291,13 @@ def translate(x: Any, *, name: str | None = None) -> Any: return proxy(x, name=name) if isinstance(x, Proxy): + # register proxy name used by NumberProxies in TensorProxy.shape + if isinstance(x, TensorProxy): + for s_p in filter(lambda s: isinstance(s, Proxy), x.shape): + # TODO need to avoid name conflict here, since s_p.name + # could have conflicted with something defined earlier in + # the trace. + get_tracectx().names.add(s_p.name) if not rename_proxies: get_tracectx().names.add(x.name) return x diff --git a/thunder/core/baseutils.py b/thunder/core/baseutils.py index 120e96e1a1..55284b3b12 100644 --- a/thunder/core/baseutils.py +++ b/thunder/core/baseutils.py @@ -281,6 +281,12 @@ def _print_complex_number(c: complex) -> str: return f"complex({real_str}, {imag_str})" +def _print_slice(s: slice) -> str: + val = (s.start, s.stop, s.step) + + return f"slice({','.join(map(lambda x: x.name if isinstance(x, ProxyInterface) else str(x), val))})" + + def print_number(n: Number) -> str: if isinstance(n, complex): return _print_complex_number(n) @@ -389,7 +395,7 @@ def print_type(typ: type, /, *, with_quotes: bool = True) -> str: int: lambda b: str(b), float: _print_float_number, complex: _print_complex_number, - slice: lambda slc: str(slc), + slice: _print_slice, } diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 66e674ab77..6afbeca7ac 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -140,7 +140,8 @@ def to_printable( return x if is_collection(x): - flat, spec = tree_flatten(x) + # specify namespace="" to avoid flattening dataclasses + flat, spec = tree_flatten(x, namespace="") if flat and flat[0] is x: raise RuntimeError(f"Don't know how to flatten object of {type(x)}") printables = [] @@ -232,7 +233,8 @@ def prettyprint( return m(f"{name}({call_repr_str})") if is_collection(x): - flat, spec = tree_flatten(x) + # specify namespace="" to avoid flattening dataclasses + flat, spec = tree_flatten(x, namespace="") printed = tuple( prettyprint(x, with_type=False, literals_as_underscores=literals_as_underscores, _quote_markers=True) for x in flat diff --git a/thunder/core/prims.py b/thunder/core/prims.py index d847b30936..588043332c 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -470,6 +470,8 @@ def _collectify(x: Any, *, name: str | None = None) -> Any: return x if baseutils.is_collection(x): return CollectionProxy(x, name=name) + if isinstance(x, slice): + return CollectionProxy((x.start, x.stop, x.step), name=name) return x diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index b7978565b3..ae7744dda3 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from enum import auto, Enum from numbers import Number from typing import Type, Optional, Any, Tuple, List, Union @@ -1962,7 +1963,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = distparallel_type = getattr(t, "distparallel_type", None) _thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None) if using_symbolic_values(): - shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[history, wrap_const("shape").provenance]) + shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance]) shape = tuple( IntegerProxy( None, diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 20f09debaa..97248c75bf 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -21,7 +21,15 @@ ) -def tree_flatten(args, namespace=""): +optree.register_pytree_node( + slice, + lambda s: ([s.start, s.stop, s.step], None, None), + lambda _, children: slice(*children), + namespace=OPTREE_NAMESPACE, +) + + +def tree_flatten(args, namespace=OPTREE_NAMESPACE): if ( type(args) not in {