diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index ab104a7dc..7da62ac0e 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -305,12 +305,12 @@ def dispatch_modules(model, use_varlen_attn=False): dispatch_internlm2_attn_forward(model, use_varlen_attn) if USE_TRITON_KERNEL: dispatch_internlm2_rmsnorm_forward(model) - # replace_internlm2_rote(model) + replace_internlm2_rote(model) elif 'internlm' in model_name: dispatch_internlm_attn_forward(model, use_varlen_attn) if USE_TRITON_KERNEL: dispatch_internlm_rmsnorm_forward(model) - # replace_internlm_rote(model) + replace_internlm_rote(model) elif 'llama' in model_name: dispatch_llama_attn_forward(model, use_varlen_attn) if USE_TRITON_KERNEL: