From 04d528a7fec7c3db8626d7eef2f5dcf1ae36308c Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 28 Nov 2024 21:27:05 +0900 Subject: [PATCH] reduce return values by one Signed-off-by: Masaki Kozuki --- thunder/__init__.py | 2 +- thunder/executors/torch_autograd.py | 4 ++-- thunder/transforms/tensor_subclasses.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index acf59c8f7..01720d341 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -663,7 +663,7 @@ def get_computation_and_inputs(*args, **kwargs): # by split_forward_backward _tensor_subclass_transform_applied = True if not _tensor_subclass_transform_applied: - computation_trc, _ = flatten_tensor_subclasses(computation_trc) + computation_trc = flatten_tensor_subclasses(computation_trc) if backward_trc is None: from thunder.executors.passes import transform_for_execution as transform_for_execution_pass diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 3ed3b20ec..2f84cf3d9 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -155,7 +155,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # not any other container type. So we need to flatten the outputs of # the forward trace and inputs of the backward trace. fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) - fw_trace, fw_tensor_subclass_desugar = flatten_tensor_subclasses(fw_trace) + fw_trace = flatten_tensor_subclasses(fw_trace) fw_traces = [fw_trace] bw_traces = [bw_trace] @@ -247,7 +247,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if getattr(compile_data.fn, "use_fsdp", False): bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) - bw_trace, bw_tensor_subclass_desugar = flatten_tensor_subclasses(bw_trace) + bw_trace = flatten_tensor_subclasses(bw_trace) # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py index e9ab709a3..22884db7d 100644 --- a/thunder/transforms/tensor_subclasses.py +++ b/thunder/transforms/tensor_subclasses.py @@ -585,7 +585,7 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx) -def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, DesugarTensorSubclass]: +def flatten_tensor_subclasses(trace: TraceCtx) -> TraceCtx: """Flatten tensor subclasses in ``computation_trace``. Two things are happening inside of this function: @@ -605,23 +605,23 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, De the last few lines of the trace, right before return statement). Args: - computation_trace: + trace: Returns: TraceCtx: transformed trace that is free from tensor subclasses, every ``__torch_dispatch__`` behavior is spelled out. """ - desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=computation_trace) + desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=trace) updated_bsyms: list[BoundSymbol] = [] bsym: BoundSymbol - for bsym in computation_trace.bound_symbols: + for bsym in trace.bound_symbols: maybe_desugared_bsyms = desugar_tensor_subclass(bsym) updated_bsyms.extend(maybe_desugared_bsyms) if not desugar_tensor_subclass.subclass_proxy_to_flatten: - return computation_trace, None + return trace - computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace) + computation_trace_with_subclass_tensor_proxy_output = from_trace(trace) computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms) computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared")) - return computation_trace_with_subclass_tensor_proxy_output, desugar_tensor_subclass + return computation_trace_with_subclass_tensor_proxy_output