From 5fcfced23f3b460f04ba533d9e13a4b220d74a96 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 30 Nov 2024 07:02:49 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/tests/test_tensor_subclass.py   |  6 +++--
 thunder/transforms/tensor_subclasses.py | 33 ++++++++++++++-----------
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py
index 85618ba3e..e63630a30 100644
--- a/thunder/tests/test_tensor_subclass.py
+++ b/thunder/tests/test_tensor_subclass.py
@@ -253,7 +253,7 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
             not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
             reason="Requires capability >= 8.9 and torchao",
         ),
-        pytest.mark.parametrize("bias", (True, False))
+        pytest.mark.parametrize("bias", (True, False)),
     ),
 )
 def test_torchao_float8_linear(executor, device, dtype, bias):
@@ -294,7 +294,9 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
         torch.testing.assert_close(actual, expected)
         return
 
-    if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor):
+    if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
+        not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
+    ):
         pytest.xfail("numerical error")
 
     torch.testing.assert_close(actual, expected)
diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py
index 5eb2e2411..677dbc60e 100644
--- a/thunder/transforms/tensor_subclasses.py
+++ b/thunder/transforms/tensor_subclasses.py
@@ -254,23 +254,26 @@ def __post_init__(self) -> None:
         if len(self.computation_trace.bound_symbols) > 6:
             maybe_unpack_C0_bsym = self.computation_trace.bound_symbols[4]
             maybe_unpack_C1_bsym = self.computation_trace.bound_symbols[5]
-            is_backward_trace = maybe_unpack_C0_bsym.args and maybe_unpack_C1_bsym.args and (
-                maybe_unpack_C0_bsym.sym.id,
-                maybe_unpack_C1_bsym.sym.id,
-                getattr(maybe_unpack_C0_bsym.args[0], "name", ""),
-                getattr(maybe_unpack_C1_bsym.args[0], "name", ""),
-            ) == (
-                prims.PrimIDs.UNPACK_SEQUENCE,
-                prims.PrimIDs.UNPACK_SEQUENCE,
-                "C0",
-                "C1",
+            is_backward_trace = (
+                maybe_unpack_C0_bsym.args
+                and maybe_unpack_C1_bsym.args
+                and (
+                    maybe_unpack_C0_bsym.sym.id,
+                    maybe_unpack_C1_bsym.sym.id,
+                    getattr(maybe_unpack_C0_bsym.args[0], "name", ""),
+                    getattr(maybe_unpack_C1_bsym.args[0], "name", ""),
+                )
+                == (
+                    prims.PrimIDs.UNPACK_SEQUENCE,
+                    prims.PrimIDs.UNPACK_SEQUENCE,
+                    "C0",
+                    "C1",
+                )
             )
             if is_backward_trace:
                 self.flat_trace_args, _ = tree_flatten((maybe_unpack_C0_bsym.output, maybe_unpack_C1_bsym.output))
         if not is_backward_trace:
-            self.flat_trace_args, _ = tree_flatten(
-                (self.computation_trace.args, self.computation_trace.kwargs)
-            )
+            self.flat_trace_args, _ = tree_flatten((self.computation_trace.args, self.computation_trace.kwargs))
         for arg in self.flat_trace_args:
             if isinstance(arg, SubclassTensorProxy):
                 self.subclass_proxy_to_flatten.add(variableify(arg))
@@ -679,6 +682,8 @@ def flatten_tensor_subclasses(trace: TraceCtx) -> TraceCtx:
 
     computation_trace_with_subclass_tensor_proxy_output = from_trace(trace)
     computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
-    computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)"))
+    computation_trace_with_subclass_tensor_proxy_output.set_provenance(
+        TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)")
+    )
     warn_tensor_subclass_support()
     return computation_trace_with_subclass_tensor_proxy_output