From 7c1fea6b8b3e1c43806e6c3fcddfa191e8f3b590 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 29 Nov 2024 14:24:33 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/tests/test_tensor_subclass.py   |  6 +++--
 thunder/transforms/tensor_subclasses.py | 33 ++++++++++++++-----------
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py
index c4c6e0f46..a76d3ff4c 100644
--- a/thunder/tests/test_tensor_subclass.py
+++ b/thunder/tests/test_tensor_subclass.py
@@ -254,7 +254,7 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
             not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
             reason="Requires capability >= 8.9 and torchao",
         ),
-        pytest.mark.parametrize("bias", (True, False))
+        pytest.mark.parametrize("bias", (True, False)),
     ),
 )
 def test_torchao_float8_linear(executor, device, dtype, bias):
@@ -287,7 +287,9 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
     jitted = executor.make_callable(fp8_model)
 
     if bias and dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor:
-        with pytest.raises(AssertionError, match="unexpected a_role GemmInputRole.GRAD_OUTPUT and b_role GemmInputRole.GRAD_OUTPUT"):
+        with pytest.raises(
+            AssertionError, match="unexpected a_role GemmInputRole.GRAD_OUTPUT and b_role GemmInputRole.GRAD_OUTPUT"
+        ):
             jitted(x)
         return
     actual = jitted(x)
diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py
index 70efd1a0b..1eff8694f 100644
--- a/thunder/transforms/tensor_subclasses.py
+++ b/thunder/transforms/tensor_subclasses.py
@@ -247,23 +247,26 @@ def __post_init__(self) -> None:
         if len(self.computation_trace.bound_symbols) > 6:
             maybe_unpack_C0_bsym = self.computation_trace.bound_symbols[4]
             maybe_unpack_C1_bsym = self.computation_trace.bound_symbols[5]
-            is_backward_trace = maybe_unpack_C0_bsym.args and maybe_unpack_C1_bsym.args and (
-                maybe_unpack_C0_bsym.sym.id,
-                maybe_unpack_C1_bsym.sym.id,
-                maybe_unpack_C0_bsym.args[0].name,
-                maybe_unpack_C1_bsym.args[0].name,
-            ) == (
-                prims.PrimIDs.UNPACK_SEQUENCE,
-                prims.PrimIDs.UNPACK_SEQUENCE,
-                "C0",
-                "C1",
+            is_backward_trace = (
+                maybe_unpack_C0_bsym.args
+                and maybe_unpack_C1_bsym.args
+                and (
+                    maybe_unpack_C0_bsym.sym.id,
+                    maybe_unpack_C1_bsym.sym.id,
+                    maybe_unpack_C0_bsym.args[0].name,
+                    maybe_unpack_C1_bsym.args[0].name,
+                )
+                == (
+                    prims.PrimIDs.UNPACK_SEQUENCE,
+                    prims.PrimIDs.UNPACK_SEQUENCE,
+                    "C0",
+                    "C1",
+                )
             )
         if is_backward_trace:
             self.flat_trace_args, _ = tree_flatten((maybe_unpack_C0_bsym.output, maybe_unpack_C1_bsym.output))
         if not is_backward_trace:
-            self.flat_trace_args, _ = tree_flatten(
-                (self.computation_trace.args, self.computation_trace.kwargs)
-            )
+            self.flat_trace_args, _ = tree_flatten((self.computation_trace.args, self.computation_trace.kwargs))
         for arg in self.flat_trace_args:
             if isinstance(arg, SubclassTensorProxy):
                 self.subclass_proxy_to_flatten.add(variableify(arg))
@@ -671,5 +674,7 @@ def flatten_tensor_subclasses(trace: TraceCtx) -> TraceCtx:
     end_time_ns = time.perf_counter_ns()
     elapsed_time_ns = end_time_ns - start_time_ns
     elapsed_time_millis = elapsed_time_ns // 1000000
-    computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared (took {elapsed_time_millis} milliseconds)"))
+    computation_trace_with_subclass_tensor_proxy_output.set_provenance(
+        TraceProvenance("tensor subclasses desugared (took {elapsed_time_millis} milliseconds)")
+    )
     return computation_trace_with_subclass_tensor_proxy_output
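
Note: unrelated to the formatting change, the TraceProvenance message in the
last hunk is a plain string literal on both sides of the diff, so the
{elapsed_time_millis} placeholder is never interpolated into the provenance
text; this patch only rewraps the call. A minimal sketch of a follow-up fix,
assuming the message is meant to report the measured time, would add the
f-string prefix:

    # hypothetical follow-up, not part of this formatting-only patch:
    # make the message an f-string so elapsed_time_millis is interpolated
    computation_trace_with_subclass_tensor_proxy_output.set_provenance(
        TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)")
    )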