From 13fea493834f146906574e4beac28f7027c205c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:20:40 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/executors/torch_compile.py | 9 +++++++-- thunder/tests/test_torch_compile_executor.py | 6 ++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index bfae5aa08..eecb0b725 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -72,7 +72,7 @@ def make_compiled( region_trace = TraceCtx(None) region_trace.args = sorted_unique_inputs region_trace.kwargs = {} - region_trace.names = set([a.name for a in region_trace.args]) + region_trace.names = {a.name for a in region_trace.args} with tracectx(region_trace): for a in sorted_unique_inputs: prims.unpack_trivial(a, name=a.name) @@ -86,7 +86,12 @@ def make_compiled( if o is not None: region_trace.add_name(o.name) for sbsym in bsym.subsymbols: - list(map(lambda o: region_trace.add_name(o.name), filter(lambda o: o is not None and o.name not in region_trace.names, sbsym.flat_outs))) + list( + map( + lambda o: region_trace.add_name(o.name), + filter(lambda o: o is not None and o.name not in region_trace.names, sbsym.flat_outs), + ) + ) # maybe make this the default if no sig info is present? region_trace._siginfo = SigInfo("to_be_compiled") diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index 545c3c05d..684c5dedb 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -91,15 +91,17 @@ def fn(a): jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,)) assert_close(jfn(a), fn(a)) + @pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") @requiresCUDA @pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported") def test_litgpt_fabric_for_callable(): - from typing import Any, Callable, Optional, Tuple, Union, List, Dict + from typing import Any, Optional, Tuple, Union, List, Dict + from collections.abc import Callable from litgpt.model import Config, GPT import torch.nn as nn - def jit(fn: Callable, executors: List[str]) -> Any: + def jit(fn: Callable, executors: list[str]) -> Any: assert executors is not None return thunder.jit(fn, executors=executors)