diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py index f9664f13f..e3c5dc4d5 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/graph_runner.py @@ -33,10 +33,17 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, dlinfer.graph.config.enable_graph_mode = True self.patch_kernels_custom_op() self.patch_kvcache_static_shape() - self.model = torch.compile(self.model, - fullgraph=True, - dynamic=True, - backend='atbgraph') + if hasattr(self.model, 'language_model'): + self.model.language_model = torch.compile( + self.model.language_model, + fullgraph=True, + dynamic=True, + backend='atbgraph') + else: + self.model = torch.compile(self.model, + fullgraph=True, + dynamic=True, + backend='atbgraph') def check_enable_graph(self): """check enable graph."""