From 6fe16be2a145ee11b4dc2e218014f949b450d89f Mon Sep 17 00:00:00 2001
From: niushengxiao
Date: Fri, 25 Apr 2025 10:33:59 +0800
Subject: [PATCH] fix a bug in the flashinfer for deepseek2

---
 lightllm/models/deepseek2/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
index 88db2ca37..65f5e0e54 100644
--- a/lightllm/models/deepseek2/model.py
+++ b/lightllm/models/deepseek2/model.py
@@ -71,7 +71,6 @@ def _init_inferstate_cls(self):
             self.infer_state_class = Deepseek2FlashAttentionStateInfo
         elif self.enable_flashinfer:
             self.infer_state_class = Deepseek2FlashInferStateInfo
-            self.flashinfer_extra_state = DeepSeek2FlashInferStateExtraInfo(self)
 
     def _init_some_value(self):
         super()._init_some_value()
@@ -83,6 +82,8 @@ def _init_some_value(self):
         self.q_lora_rank = self.config["q_lora_rank"]
         self.kv_lora_rank = self.config["kv_lora_rank"]
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
+        if self.enable_flashinfer:
+            self.flashinfer_extra_state = DeepSeek2FlashInferStateExtraInfo(self)
 
     def _init_custom(self):
         self._init_to_get_yarn_rotary()