Skip to content

Commit

Permalink
[Model] Add classification Task with Qwen2ForSequenceClassification (v…
Browse files Browse the repository at this point in the history
…llm-project#9704)

Signed-off-by: Kevin-Yang <[email protected]>
Co-authored-by: Kevin-Yang <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
  • Loading branch information
2 people authored and cooleel committed Oct 28, 2024
1 parent c3449f9 commit d5668e9
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 1 deletion.
22 changes: 22 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,28 @@ Reward Modeling
.. note::
As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.

Classification
---------------

.. list-table::
:widths: 25 25 50 5 5
:header-rows: 1

* - Architecture
- Models
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`Qwen2ForSequenceClassification`
- Qwen2-based
- :code:`jason9693/Qwen2.5-1.5B-apeach`, etc.
-
- ✅︎

.. note::
As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now).


Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,17 @@ def get_inputs(

return all_inputs

def classify(self, prompts: List[str]) -> List[str]:
# output is final logits
all_inputs = self.get_inputs(prompts)
outputs = []
for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs))
logits = output.logits.softmax(dim=-1)[0].tolist()
outputs.append(logits)

return outputs

def generate(
self,
prompts: List[str],
Expand Down Expand Up @@ -688,6 +699,14 @@ def get_inputs(

return inputs

def classify(self, prompts: List[str]) -> List[str]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs

def generate(
self,
prompts: List[str],
Expand Down
53 changes: 53 additions & 0 deletions tests/models/embedding/language/test_cls_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
This test only tests small models. Big models such as 7B should be tested from
test_big_models.py because it could use a larger instance to run tests.
Run `pytest tests/models/test_cls_models.py`.
"""
import pytest
import torch
from transformers import AutoModelForSequenceClassification

CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"]


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
) -> None:
with hf_runner(model,
dtype=dtype,
auto_cls=AutoModelForSequenceClassification) as hf_model:
hf_outputs = hf_model.classify(example_prompts)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)

print(hf_outputs, vllm_outputs)

# check logits difference
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = torch.tensor(hf_output)
vllm_output = torch.tensor(vllm_output)

assert torch.allclose(hf_output, vllm_output, 1e-3)


@pytest.mark.parametrize("model", CLASSIFICATION_MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_classification_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
9 changes: 8 additions & 1 deletion vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ class Pooler(nn.Module):
normalize: Whether to normalize the pooled data.
"""

def __init__(self, pooling_type: PoolingType, normalize: bool):
def __init__(self,
pooling_type: PoolingType,
normalize: bool,
softmax: bool = False):
super().__init__()

self.pooling_type = pooling_type
self.normalize = normalize
self.softmax = softmax

def forward(
self,
Expand Down Expand Up @@ -64,6 +68,9 @@ def forward(
if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

if self.softmax:
pooled_data = nn.functional.softmax(pooled_data, dim=-1)

pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
]
Expand Down
107 changes: 107 additions & 0 deletions vllm/model_executor/models/qwen2_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
# Copyright 2024 Kakao Corp. (Kanana-X Team)
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-Classification model compatible with HF weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .utils import AutoWeightsLoader


class Qwen2ForSequenceClassification(nn.Module):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
raise ValueError("Sliding window for some but all layers is not "
"supported. This model uses sliding window "
"but `max_window_layers` = %s is less than "
"`num_hidden_layers` = %s. Please open an issue "
"to discuss this feature." % (
config.max_window_layers,
config.num_hidden_layers,
))

super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)

self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config)
self._pooler = Pooler(pooling_type=PoolingType.LAST,
normalize=False,
softmax=True)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
logits, _ = self.score(hidden_states)
return logits

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
loader.load_weights(weights)
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": (
"qwen2_cls", "Qwen2ForSequenceClassification"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
Expand Down

0 comments on commit d5668e9

Please sign in to comment.