diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index a2f1af9a3a..11237e4c84 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -1709,6 +1709,14 @@ def trunc(a: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) register_supported(PrimIDs.TRUNC, trunc, _elementwise_unary_check) +def clone(a: TensorProxy, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any: + nva = getnv(a, fd, lc_to_nv_map) + + return fd.ops.set(nva) + + +register_supported(PrimIDs.CLONE, clone, _elementwise_unary_check) + # # Elementwise binary operations # diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 57aea7f871..788ea29992 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -510,6 +510,7 @@ def test_hf_for_nemo(model_id): def test_hf_llama(): from transformers.models.llama import LlamaForCausalLM, LlamaConfig from transformers.models.llama.modeling_llama import logger as llama_logger + from thunder.examine import get_fusion_symbols import logging # transformers logs a cache deprecation warning @@ -548,9 +549,8 @@ def test_hf_llama(): expected2 = model(past_key_values=res["past_key_values"], **args2) assert_close(res2, expected2, rtol=1e-1, atol=1e-1) - top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols} # changes this to fewer as needed, the goal is to not have too many fusions - assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7 + assert len(get_fusion_symbols(thunder.last_traces(jm)[-1])) == 7 @requiresCUDA