diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index bfae5aa08..eecb0b725 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -72,7 +72,7 @@ def make_compiled( region_trace = TraceCtx(None) region_trace.args = sorted_unique_inputs region_trace.kwargs = {} - region_trace.names = set([a.name for a in region_trace.args]) + region_trace.names = {a.name for a in region_trace.args} with tracectx(region_trace): for a in sorted_unique_inputs: prims.unpack_trivial(a, name=a.name) @@ -86,7 +86,12 @@ def make_compiled( if o is not None: region_trace.add_name(o.name) for sbsym in bsym.subsymbols: - list(map(lambda o: region_trace.add_name(o.name), filter(lambda o: o is not None and o.name not in region_trace.names, sbsym.flat_outs))) + list( + map( + lambda o: region_trace.add_name(o.name), + filter(lambda o: o is not None and o.name not in region_trace.names, sbsym.flat_outs), + ) + ) # maybe make this the default if no sig info is present? region_trace._siginfo = SigInfo("to_be_compiled") diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index 545c3c05d..684c5dedb 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -91,15 +91,17 @@ def fn(a): jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,)) assert_close(jfn(a), fn(a)) + @pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") @requiresCUDA @pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported") def test_litgpt_fabric_for_callable(): - from typing import Any, Callable, Optional, Tuple, Union, List, Dict + from typing import Any, Optional, Tuple, Union, List, Dict + from collections.abc import Callable from litgpt.model import Config, GPT import torch.nn as nn - def jit(fn: Callable, executors: List[str]) -> Any: + def jit(fn: Callable, executors: list[str]) -> Any: assert executors is not None return thunder.jit(fn, executors=executors)