diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 9b2c9255d7354..ee6d36f69506f 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -26,7 +26,6 @@ import torch.utils.checkpoint from torch import nn from transformers import CohereConfig -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from vllm.attention import Attention, AttentionMetadata from vllm.model_executor.layers.activation import SiluAndMul @@ -46,8 +45,6 @@ hf_model_weights_iterator) from vllm.sequence import SamplerOutput -KVCache = Tuple[torch.Tensor, torch.Tensor] - class LayerNorm(nn.Module): @@ -70,9 +67,6 @@ def forward(self, hidden_states, residuals=None): return hidden_states.to(input_dtype), residuals -ALL_LAYERNORM_LAYERS.append(LayerNorm) - - # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): @@ -137,7 +131,6 @@ def __init__( self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) - self.is_causal = True self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_dim, @@ -171,7 +164,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) @@ -200,7 +193,7 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - kv_cache: KVCache, + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -242,7 +235,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -269,7 +262,6 @@ def __init__( ) -> None: super().__init__() self.config = config - self.unpadded_vocab_size = config.vocab_size self.linear_method = linear_method self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) @@ -281,7 +273,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - kv_caches: List[KVCache], + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches,