Skip to content

Commit

Permalink
reduce return values by one
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <[email protected]>
  • Loading branch information
crcrpar committed Nov 28, 2024
1 parent 06ee30e commit 04d528a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def get_computation_and_inputs(*args, **kwargs):
# by split_forward_backward
_tensor_subclass_transform_applied = True
if not _tensor_subclass_transform_applied:
computation_trc, _ = flatten_tensor_subclasses(computation_trc)
computation_trc = flatten_tensor_subclasses(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
Expand Down
4 changes: 2 additions & 2 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
# not any other container type. So we need to flatten the outputs of
# the forward trace and inputs of the backward trace.
fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
fw_trace, fw_tensor_subclass_desugar = flatten_tensor_subclasses(fw_trace)
fw_trace = flatten_tensor_subclasses(fw_trace)

fw_traces = [fw_trace]
bw_traces = [bw_trace]
Expand Down Expand Up @@ -247,7 +247,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
if getattr(compile_data.fn, "use_fsdp", False):
bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)

bw_trace, bw_tensor_subclass_desugar = flatten_tensor_subclasses(bw_trace)
bw_trace = flatten_tensor_subclasses(bw_trace)

# Now we can run the optimization passes on the backward trace
# TODO Restore request for no rematerialization
Expand Down
14 changes: 7 additions & 7 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)


def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, DesugarTensorSubclass]:
def flatten_tensor_subclasses(trace: TraceCtx) -> TraceCtx:
"""Flatten tensor subclasses in ``computation_trace``.
Two things are happening inside of this function:
Expand All @@ -605,23 +605,23 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, De
the last few lines of the trace, right before return statement).
Args:
computation_trace:
trace:
Returns:
TraceCtx: transformed trace that is free from tensor subclasses, every ``__torch_dispatch__``
behavior is spelled out.
"""
desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=computation_trace)
desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=trace)
updated_bsyms: list[BoundSymbol] = []
bsym: BoundSymbol
for bsym in computation_trace.bound_symbols:
for bsym in trace.bound_symbols:
maybe_desugared_bsyms = desugar_tensor_subclass(bsym)
updated_bsyms.extend(maybe_desugared_bsyms)

if not desugar_tensor_subclass.subclass_proxy_to_flatten:
return computation_trace, None
return trace

computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace)
computation_trace_with_subclass_tensor_proxy_output = from_trace(trace)
computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared"))
return computation_trace_with_subclass_tensor_proxy_output, desugar_tensor_subclass
return computation_trace_with_subclass_tensor_proxy_output

0 comments on commit 04d528a

Please sign in to comment.