diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index 4a8e6b1d3..ce95b91bf 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -72,7 +72,7 @@ def make_compiled(
     region_trace = TraceCtx(None)
     region_trace.args = sorted_unique_inputs
     region_trace.kwargs = {}
-    region_trace.names = set([a.name for a in region_trace.args])
+    region_trace.names = {a.name for a in region_trace.args}
     with tracectx(region_trace):
         for a in sorted_unique_inputs:
             prims.unpack_trivial(a, name=a.name)
diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py
index 545c3c05d..684c5dedb 100644
--- a/thunder/tests/test_torch_compile_executor.py
+++ b/thunder/tests/test_torch_compile_executor.py
@@ -91,15 +91,17 @@ def fn(a):
     jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,))
     assert_close(jfn(a), fn(a))
 
+@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported")
 @requiresCUDA
 @pytest.mark.skipif(not device_supports_bf16(torch.device("cuda")), reason="bf16 is not supported")
 def test_litgpt_fabric_for_callable():
-    from typing import Any, Callable, Optional, Tuple, Union, List, Dict
+    from typing import Any, Optional, Tuple, Union, List, Dict
+    from collections.abc import Callable
     from litgpt.model import Config, GPT
     import torch.nn as nn
 
-    def jit(fn: Callable, executors: List[str]) -> Any:
+    def jit(fn: Callable, executors: list[str]) -> Any:
         assert executors is not None
         return thunder.jit(fn, executors=executors)