[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 29, 2024
1 parent 4f693e6 commit 7c1fea6
Showing 2 changed files with 23 additions and 16 deletions.
thunder/tests/test_tensor_subclass.py (6 changes: 4 additions & 2 deletions)
@@ -254,7 +254,7 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
             not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
             reason="Requires capability >= 8.9 and torchao",
         ),
-        pytest.mark.parametrize("bias", (True, False))
+        pytest.mark.parametrize("bias", (True, False)),
     ),
 )
 def test_torchao_float8_linear(executor, device, dtype, bias):
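The only change in this hunk is the trailing comma after the last pytest.mark.parametrize entry. For readers unfamiliar with why an autoformatter adds it: under Black's "magic trailing comma" convention (assuming this repository's formatter follows it, as Black and Ruff's formatter do), a trailing comma after the final element pins a collection to one element per line. A minimal sketch with hypothetical marks, not code from this repository:

import pytest

# Hypothetical snippet illustrating the magic trailing comma.
# Without a trailing comma, the formatter may collapse the tuple onto one line:
marks = (pytest.mark.slow, pytest.mark.cuda)

# With a trailing comma after the last element, the exploded,
# one-element-per-line layout is preserved on reformatting:
marks = (
    pytest.mark.slow,
    pytest.mark.cuda,
)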
@@ -287,7 +287,9 @@ def test_torchao_float8_linear(executor, device, dtype, bias):
     jitted = executor.make_callable(fp8_model)

     if bias and dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor:
-        with pytest.raises(AssertionError, match="unexpected a_role GemmInputRole.GRAD_OUTPUT and b_role GemmInputRole.GRAD_OUTPUT"):
+        with pytest.raises(
+            AssertionError, match="unexpected a_role GemmInputRole.GRAD_OUTPUT and b_role GemmInputRole.GRAD_OUTPUT"
+        ):
             jitted(x)
         return
     actual = jitted(x)
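The re-wrapped pytest.raises call above relies on the match argument, which pytest applies as a regular expression (via re.search) against the string form of the raised exception. A self-contained sketch of the pattern, using a made-up function and message rather than the test's real ones:

import pytest


def scaled_matmul(a: float, b: float) -> float:
    # Made-up stand-in that fails the way the test above expects.
    raise AssertionError("unexpected a_role GRAD_OUTPUT and b_role GRAD_OUTPUT")


def test_scaled_matmul_raises() -> None:
    # match is searched against str(excinfo.value),
    # so a substring of the message is enough.
    with pytest.raises(AssertionError, match="unexpected a_role"):
        scaled_matmul(1.0, 2.0)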
thunder/transforms/tensor_subclasses.py (33 changes: 19 additions & 14 deletions)
@@ -247,23 +247,26 @@ def __post_init__(self) -> None:
         if len(self.computation_trace.bound_symbols) > 6:
             maybe_unpack_C0_bsym = self.computation_trace.bound_symbols[4]
             maybe_unpack_C1_bsym = self.computation_trace.bound_symbols[5]
-            is_backward_trace = maybe_unpack_C0_bsym.args and maybe_unpack_C1_bsym.args and (
-                maybe_unpack_C0_bsym.sym.id,
-                maybe_unpack_C1_bsym.sym.id,
-                maybe_unpack_C0_bsym.args[0].name,
-                maybe_unpack_C1_bsym.args[0].name,
-            ) == (
-                prims.PrimIDs.UNPACK_SEQUENCE,
-                prims.PrimIDs.UNPACK_SEQUENCE,
-                "C0",
-                "C1",
+            is_backward_trace = (
+                maybe_unpack_C0_bsym.args
+                and maybe_unpack_C1_bsym.args
+                and (
+                    maybe_unpack_C0_bsym.sym.id,
+                    maybe_unpack_C1_bsym.sym.id,
+                    maybe_unpack_C0_bsym.args[0].name,
+                    maybe_unpack_C1_bsym.args[0].name,
+                )
+                == (
+                    prims.PrimIDs.UNPACK_SEQUENCE,
+                    prims.PrimIDs.UNPACK_SEQUENCE,
+                    "C0",
+                    "C1",
+                )
             )
             if is_backward_trace:
                 self.flat_trace_args, _ = tree_flatten((maybe_unpack_C0_bsym.output, maybe_unpack_C1_bsym.output))
             if not is_backward_trace:
-                self.flat_trace_args, _ = tree_flatten(
-                    (self.computation_trace.args, self.computation_trace.kwargs)
-                )
+                self.flat_trace_args, _ = tree_flatten((self.computation_trace.args, self.computation_trace.kwargs))
         for arg in self.flat_trace_args:
             if isinstance(arg, SubclassTensorProxy):
                 self.subclass_proxy_to_flatten.add(variableify(arg))
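Behavior in the hunk above is unchanged; only the wrapping of the is_backward_trace expression and the tree_flatten calls differs. For context, tree_flatten flattens an arbitrarily nested structure of containers into a list of leaves plus a spec that can rebuild it, which is how the transform collects every trace argument regardless of nesting. A standalone sketch using torch.utils._pytree (an assumption that thunder's tree_flatten shares these semantics; this is not thunder's own code):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Mirror tree_flatten((trace.args, trace.kwargs)) from the diff above.
args = (torch.zeros(2), [torch.ones(3), 4.0])
kwargs = {"scale": torch.tensor(0.5)}
flat, spec = tree_flatten((args, kwargs))
# flat holds the leaves in order: the zeros(2) tensor, the ones(3) tensor,
# 4.0, and the scale tensor. spec records the nesting, so the original
# structure can be rebuilt from the flat list:
rebuilt_args, rebuilt_kwargs = tree_unflatten(flat, spec)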
@@ -671,5 +674,7 @@ def flatten_tensor_subclasses(trace: TraceCtx) -> TraceCtx:
     end_time_ns = time.perf_counter_ns()
     elapsed_time_ns = end_time_ns - start_time_ns
     elapsed_time_millis = elapsed_time_ns // 1000000
-    computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)"))
+    computation_trace_with_subclass_tensor_proxy_output.set_provenance(
+        TraceProvenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)")
+    )
     return computation_trace_with_subclass_tensor_proxy_output
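The timing idiom in this final hunk is self-contained: a nanosecond monotonic counter, integer division down to milliseconds, and an f-string provenance message (the f prefix is what makes {elapsed_time_millis} interpolate rather than appear literally). A minimal runnable sketch, with a hypothetical Provenance class standing in for thunder's TraceProvenance:

import time
from dataclasses import dataclass


@dataclass
class Provenance:
    """Hypothetical stand-in for thunder's TraceProvenance."""

    message: str


start_time_ns = time.perf_counter_ns()
sum(range(1_000_000))  # stand-in for the actual desugaring work
end_time_ns = time.perf_counter_ns()
elapsed_time_millis = (end_time_ns - start_time_ns) // 1_000_000

# Without the f prefix, the literal text "{elapsed_time_millis}" would be stored.
provenance = Provenance(f"tensor subclasses desugared (took {elapsed_time_millis} milliseconds)")
print(provenance.message)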
