Skip to content

Commit

Permalink
Numberproxy slice (#1201)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjsjann123 authored Oct 3, 2024
1 parent c4c3ce3 commit da1a441
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 5 deletions.
5 changes: 5 additions & 0 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,11 @@ def getitem(a: TensorLike, /, key) -> TensorLike:
lambda: f"{key=} tries to index more dimensions than {a.ndim=}",
)

# FIXME: This is a quick WAR to avoid accessing shape attribute of a without
# definition. This needs to be done properly somewhere else. See issue
# github.com/Lightning-AI/lightning-thunder/issues/1253
old_shape = prims.shape(a)

# We do not support mixing basic and advanced indexing together yet,
# but a very special case when there is a single advanced index which
# is a sequence of length 1.
Expand Down
7 changes: 7 additions & 0 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,13 @@ def translate(x: Any, *, name: str | None = None) -> Any:
return proxy(x, name=name)

if isinstance(x, Proxy):
# register proxy name used by NumberProxies in TensorProxy.shape
if isinstance(x, TensorProxy):
for s_p in filter(lambda s: isinstance(s, Proxy), x.shape):
# TODO need to avoid name conflict here, since s_p.name
# could have conflicted with something defined earlier in
# the trace.
get_tracectx().names.add(s_p.name)
if not rename_proxies:
get_tracectx().names.add(x.name)
return x
Expand Down
8 changes: 7 additions & 1 deletion thunder/core/baseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ def _print_complex_number(c: complex) -> str:
return f"complex({real_str}, {imag_str})"


def _print_slice(s: slice) -> str:
val = (s.start, s.stop, s.step)

return f"slice({','.join(map(lambda x: x.name if isinstance(x, ProxyInterface) else str(x), val))})"


def print_number(n: Number) -> str:
if isinstance(n, complex):
return _print_complex_number(n)
Expand Down Expand Up @@ -389,7 +395,7 @@ def print_type(typ: type, /, *, with_quotes: bool = True) -> str:
int: lambda b: str(b),
float: _print_float_number,
complex: _print_complex_number,
slice: lambda slc: str(slc),
slice: _print_slice,
}


Expand Down
6 changes: 4 additions & 2 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def to_printable(
return x

if is_collection(x):
flat, spec = tree_flatten(x)
# specify namespace="" to avoid flattening dataclasses
flat, spec = tree_flatten(x, namespace="")
if flat and flat[0] is x:
raise RuntimeError(f"Don't know how to flatten object of {type(x)}")
printables = []
Expand Down Expand Up @@ -232,7 +233,8 @@ def prettyprint(
return m(f"{name}({call_repr_str})")

if is_collection(x):
flat, spec = tree_flatten(x)
# specify namespace="" to avoid flattening dataclasses
flat, spec = tree_flatten(x, namespace="")
printed = tuple(
prettyprint(x, with_type=False, literals_as_underscores=literals_as_underscores, _quote_markers=True)
for x in flat
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ def _collectify(x: Any, *, name: str | None = None) -> Any:
return x
if baseutils.is_collection(x):
return CollectionProxy(x, name=name)
if isinstance(x, slice):
return CollectionProxy((x.start, x.stop, x.step), name=name)

return x

Expand Down
3 changes: 2 additions & 1 deletion thunder/core/proxies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
from enum import auto, Enum
from numbers import Number
from typing import Type, Optional, Any, Tuple, List, Union
Expand Down Expand Up @@ -1962,7 +1963,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
distparallel_type = getattr(t, "distparallel_type", None)
_thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
if using_symbolic_values():
shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[history, wrap_const("shape").provenance])
shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance])
shape = tuple(
IntegerProxy(
None,
Expand Down
10 changes: 9 additions & 1 deletion thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@
)


def tree_flatten(args, namespace=""):
optree.register_pytree_node(
slice,
lambda s: ([s.start, s.stop, s.step], None, None),
lambda _, children: slice(*children),
namespace=OPTREE_NAMESPACE,
)


def tree_flatten(args, namespace=OPTREE_NAMESPACE):
if (
type(args)
not in {
Expand Down

0 comments on commit da1a441

Please sign in to comment.