From 94b0332f86f5036d545aac2dce611a7ec77c83a6 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 18 Jul 2024 16:31:50 -0400
Subject: [PATCH] [Model] Support Mistral-Nemo (#6548)

---
 vllm/model_executor/models/llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 4c434e54cf743..08f449f20305a 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -89,6 +89,7 @@ class LlamaAttention(nn.Module):
 
     def __init__(
         self,
+        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -115,7 +116,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -189,6 +192,7 @@ def __init__(
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
         self.self_attn = LlamaAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=getattr(config, "num_key_value_heads",
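
Note (not part of the patch): a minimal sketch of the fallback the second hunk introduces. Mistral-Nemo publishes an explicit head_dim that differs from hidden_size // num_attention_heads, which is why the derived value can no longer be assumed. The config objects below are hypothetical stand-ins (not vLLM or Transformers classes), and the numeric values are assumptions chosen for illustration.

from types import SimpleNamespace

# Hypothetical configs for illustration only; values are assumptions.
llama_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32)               # no head_dim attribute
nemo_like = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)  # explicit head_dim

for cfg in (llama_like, nemo_like):
    derived = cfg.hidden_size // cfg.num_attention_heads
    # Same pattern as the patch: prefer an explicit config.head_dim, else derive it.
    head_dim = getattr(cfg, "head_dim", derived)
    print(derived, head_dim)
# llama_like: 128 128  (no head_dim attribute, the derived value is used)
# nemo_like:  160 128  (explicit head_dim=128 overrides the derived 160)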