From ed46f143212203b7afcbc8538119b6e8155c643e Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Mon, 25 Nov 2024 17:51:20 +0800
Subject: [PATCH] [Model] Support `is_causal` HF config field for Qwen2 model
 (#10621)

Signed-off-by: DarkLight1337
---
 docs/source/models/supported_models.rst  | 13 +++++++++---
 .../embedding/language/test_embedding.py | 12 +++++++++--
 tests/models/embedding/utils.py          |  4 ++--
 vllm/config.py                           | 15 ++++++++++----
 vllm/model_executor/models/qwen2.py      | 20 +++++++++++++++++--
 5 files changed, 51 insertions(+), 13 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index ccd2d8de8ec0b..54e2c4479c2c9 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -342,7 +342,7 @@ Text Embedding
     - ✅︎
   * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM`
     - Qwen2-based
-    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc.
+    - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
     - ✅︎
     - ✅︎
   * - :code:`RobertaModel`, :code:`RobertaForMaskedLM`
@@ -363,6 +363,13 @@ Text Embedding
 .. tip::
     You can override the model's pooling method by passing :code:`--override-pooler-config`.
 
+.. note::
+    Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention.
+    You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly.
+
+    On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention
+    despite being described otherwise on its model card.
+
 Reward Modeling
 ---------------
 
@@ -606,10 +613,10 @@ Text Generation
 | :sup:`+` Multiple items can be inputted per text prompt for this modality.
 
 .. note::
-  vLLM currently only supports adding LoRA to the language backbone of multimodal models. 
+  vLLM currently only supports adding LoRA to the language backbone of multimodal models.
 
 .. note::
-  For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
+  The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
   For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
 
 Multimodal Embedding
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index c3f351ef707be..36b1e5887981c 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -21,6 +21,7 @@
                      marks=[pytest.mark.core_model]),
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
@@ -31,6 +32,10 @@ def test_models(
     model,
     dtype: str,
 ) -> None:
+    vllm_extra_kwargs = {}
+    if model == "Alibaba-NLP/gte-Qwen2-7B-instruct":
+        vllm_extra_kwargs["hf_overrides"] = {"is_causal": False}
+
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -43,8 +48,11 @@ def test_models(
                    is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
 
-    with vllm_runner(model, task="embedding", dtype=dtype,
-                     max_model_len=None) as vllm_model:
+    with vllm_runner(model,
+                     task="embedding",
+                     dtype=dtype,
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
     # This test is for verifying whether the model's extra_repr
     # can be printed correctly.
diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py
index fd1c44d9c117e..f96c7d2b176db 100644
--- a/tests/models/embedding/utils.py
+++ b/tests/models/embedding/utils.py
@@ -24,7 +24,7 @@ def check_embeddings_close(
                                   dim=0)
 
         fail_msg = (f"Test{prompt_idx}:"
-                    f"\n{name_0}:\t{embeddings_0!r}"
-                    f"\n{name_1}:\t{embeddings_1!r}")
+                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
+                    f"\n{name_1}:\t{embeddings_1[:16]!r}")
 
         assert sim >= 1 - tol, fail_msg
diff --git a/vllm/config.py b/vllm/config.py
index 0a390c4311ba6..f9ecb02cd5bde 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -27,7 +27,7 @@
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once, resolve_obj_by_qualname)
+                        print_warning_once, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -183,7 +183,7 @@ def __init__(
             hf_overrides_fn = hf_overrides
         else:
             hf_overrides_kw = hf_overrides
-            hf_overrides_fn = identity
+            hf_overrides_fn = None
 
         if rope_scaling is not None:
             hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling}
@@ -212,8 +212,15 @@ def __init__(
         self.skip_tokenizer_init = skip_tokenizer_init
 
         hf_config = get_config(self.model, trust_remote_code, revision,
-                               code_revision, config_format, **hf_overrides_kw)
-        hf_config = hf_overrides_fn(hf_config)
+                               code_revision, config_format)
+
+        if hf_overrides_kw:
+            logger.info("Overriding HF config with %s", hf_overrides_kw)
+            hf_config.update(hf_overrides_kw)
+        if hf_overrides_fn:
+            logger.info("Overriding HF config with %s", hf_overrides_fn)
+            hf_config = hf_overrides_fn(hf_config)
+
         self.hf_config = hf_config
 
         self.hf_text_config = get_hf_text_config(self.hf_config)
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 370cff5fa153f..8da75c9935a13 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -27,7 +27,7 @@
 from torch import nn
 from transformers import Qwen2Config
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,17 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=attn_type)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -210,6 +216,15 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
 
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            self._attn_type = AttentionType.DECODER
+        else:
+            self._attn_type = AttentionType.ENCODER_ONLY
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -230,6 +245,7 @@ def forward(
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            attn_type=self._attn_type,
         )
 
         # Fully Connected
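
A quick usage sketch of the behavior this patch enables. It assumes the `hf_overrides` and `task` arguments accepted by the `vllm_runner` fixture in the test above are likewise accepted by the `LLM` constructor, and that `LLM.encode` returns embedding outputs as in the test; adjust names to the installed vLLM version.

# Offline embedding with bidirectional attention enabled via the HF config
# override introduced by this patch.
from vllm import LLM

llm = LLM(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    task="embedding",
    hf_overrides={"is_causal": False},  # selects AttentionType.ENCODER_ONLY
)

outputs = llm.encode(["What is the capital of France?"])
print(len(outputs[0].outputs.embedding))  # embedding dimensionality

The server-side equivalent, following the documentation note above, would be something like:
vllm serve Alibaba-NLP/gte-Qwen2-7B-instruct --task embedding --hf-overrides '{"is_causal": false}'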
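
For reference, the override handling added to `ModelConfig.__init__` (keyword overrides applied via `hf_config.update`, callable overrides applied to the config object) can be reproduced in isolation with `transformers` alone. The helper name `apply_hf_overrides` below is illustrative only, not part of the vLLM API.

# Minimal sketch of the two override forms supported by ModelConfig.__init__.
from typing import Any, Callable, Dict, Optional, Union

from transformers import AutoConfig, PretrainedConfig

ConfigOverride = Union[Dict[str, Any],
                       Callable[[PretrainedConfig], PretrainedConfig]]


def apply_hf_overrides(
        hf_config: PretrainedConfig,
        hf_overrides: Optional[ConfigOverride]) -> PretrainedConfig:
    if callable(hf_overrides):
        # Callable form: the caller rewrites the config object directly.
        return hf_overrides(hf_config)
    if hf_overrides:
        # Keyword form: shallow-update attributes such as `is_causal`.
        hf_config.update(hf_overrides)
    return hf_config


config = AutoConfig.from_pretrained("Alibaba-NLP/gte-Qwen2-7B-instruct")
config = apply_hf_overrides(config, {"is_causal": False})
assert getattr(config, "is_causal", True) is False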