[Frontend] Separate pooling APIs in offline inference #11129

Merged (19 commits) on Dec 13, 2024

Changes from 8 commits
7 changes: 5 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -181,14 +181,14 @@ steps:
commands:
- VLLM_USE_V1=1 pytest -v -s v1

- label: Examples Test # 15min
- label: Examples Test # 20min
working_dir: "/vllm-workspace/examples"
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/entrypoints
- examples/
commands:
- pip install awscli tensorizer # for llava example and tensorizer test
- pip install tensorizer # for tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
@@ -198,6 +198,9 @@ steps:
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_inference_classification.py
- python3 offline_inference_embedding.py
- python3 offline_inference_scoring.py
- python3 offline_profile.py --model facebook/opt-125m

- label: Prefix Caching Test # 9min
53 changes: 45 additions & 8 deletions docs/source/models/pooling_models.rst
@@ -6,7 +6,7 @@ Pooling Models
vLLM also supports pooling models, including embedding, reranking and reward models.

In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input
before returning them.

.. note::
@@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults.
^^^^^^^^^^^^^^

The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
It returns the aggregated hidden states directly.
It returns the extracted hidden states directly, which is useful for reward models.

.. code-block:: python

llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
output, = llm.encode("Hello, my name is")

data = output.outputs.data
print(f"Prompt: {prompt!r} | Data: {data!r}")

``LLM.embed``
^^^^^^^^^^^^^

The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt.
It is primarily designed for embedding models.

.. code-block:: python

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.encode("Hello, my name is")
output, = llm.embed("Hello, my name is")

outputs = model.encode(prompts)
for output in outputs:
embeddings = output.outputs.embedding
print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}")
embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")

A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.

``LLM.classify``
^^^^^^^^^^^^^^^^

The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt.
It is primarily designed for classification models.

.. code-block:: python

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
output, = llm.classify("Hello, my name is")

probs = output.outputs.probs
print(f"Class Probabilities: {probs!r} (size={len(probs)})")

A code example can be found in `examples/offline_inference_classification.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_classification.py>`_.

``LLM.score``
^^^^^^^^^^^^^

@@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.

You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.
.. code-block:: python

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
output, = llm.score("What is the capital of France?",
"The capital of Brazil is Brasilia.")

probs = output.outputs.probs
print(f"Scores: {probs!r} (size={len(probs)})")

A code example can be found in `examples/offline_inference_scoring.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_scoring.py>`_.

Online Inference
----------------
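The docs hunk above describes the Pooler as extracting the final hidden states of the input before returning them. The snippet below is a minimal, illustrative sketch of that idea (last-token pooling over a packed batch), assuming hidden states arrive as a [total_tokens, hidden_size] tensor; it is not vLLM's actual Pooler implementation.

# Illustrative only: a toy last-token pooler, not vLLM's Pooler class.
# Assumes hidden_states is [total_tokens, hidden_size] for a packed batch
# and prompt_lens gives the token count of each prompt, in order.
from typing import List

import torch


def last_token_pool(hidden_states: torch.Tensor,
                    prompt_lens: List[int]) -> torch.Tensor:
    """Extract the final hidden state of each prompt in a packed batch."""
    # Index of the last token of each prompt within the packed tensor.
    last_indices = torch.cumsum(torch.tensor(prompt_lens), dim=0) - 1
    return hidden_states[last_indices]


# Example: 3 prompts of lengths 4, 2 and 5 packed into one tensor.
hidden = torch.randn(11, 8)
pooled = last_token_pool(hidden, [4, 2, 5])
print(pooled.shape)  # torch.Size([3, 8])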
28 changes: 28 additions & 0 deletions examples/offline_inference_classification.py
@@ -0,0 +1,28 @@
from vllm import LLM

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

# Create an LLM.
# You should pass task="classify" for classification models
model = LLM(
model="jason9693/Qwen2.5-1.5B-apeach",
task="classify",
enforce_eager=True,
)

# Generate class probabilities. The output is a list of ClassificationRequestOutputs.
outputs = model.classify(prompts)

# Print the outputs.
for prompt, output in zip(prompts, outputs):
probs = output.outputs.probs
probs_trimmed = ((str(probs[:16])[:-1] +
", ...]") if len(probs) > 16 else probs)
print(f"Prompt: {prompt!r} | "
f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
16 changes: 11 additions & 5 deletions examples/offline_inference_embedding.py
@@ -9,14 +9,20 @@
]

# Create an LLM.
# You should pass task="embed" for embedding models
model = LLM(
model="intfloat/e5-mistral-7b-instruct",
task="embed", # You should pass task="embed" for embedding models
task="embed",
enforce_eager=True,
)

# Generate embedding. The output is a list of PoolingRequestOutputs.
outputs = model.encode(prompts)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.embed(prompts)

# Print the outputs.
for output in outputs:
print(output.outputs.embedding) # list of 4096 floats
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = ((str(embeds[:16])[:-1] +
", ...]") if len(embeds) > 16 else embeds)
print(f"Prompt: {prompt!r} | "
f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
26 changes: 26 additions & 0 deletions examples/offline_inference_scoring.py
@@ -0,0 +1,26 @@
from vllm import LLM

# Sample prompts.
text_1 = "What is the capital of France?"
texts_2 = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]

# Create an LLM.
# You should pass task="score" for cross-encoder models
model = LLM(
model="BAAI/bge-reranker-v2-m3",
task="score",
enforce_eager=True,
)

# Generate scores. The output is a list of ClassificationRequestOutputs.
outputs = model.score(text_1, texts_2)

# Print the outputs.
for text_2, output in zip(texts_2, outputs):
scores = output.outputs.probs
scores_trimmed = ((str(scores[:16])[:-1] +
", ...]") if len(scores) > 16 else scores)
print(f"Pair: {[text_1, text_2]!r} | "
f"Scores: {scores_trimmed} (size={len(scores)})")
2 changes: 1 addition & 1 deletion examples/offline_inference_vision_language_embedding.py
@@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality):
if req_data.image is not None:
mm_data["image"] = req_data.image

outputs = req_data.llm.encode({
outputs = req_data.llm.embed({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})
16 changes: 6 additions & 10 deletions tests/conftest.py
@@ -719,14 +719,6 @@ def get_inputs(

return inputs

def classify(self, prompts: List[str]) -> List[str]:
req_outputs = self.model.encode(prompts)
outputs = []
for req_output in req_outputs:
embedding = req_output.outputs.embedding
outputs.append(embedding)
return outputs

def generate(
self,
prompts: List[str],
@@ -897,6 +889,10 @@ def generate_beam_search(
returned_outputs.append((token_ids, texts))
return returned_outputs

def classify(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs]

def encode(
self,
prompts: List[str],
@@ -909,7 +905,7 @@ def encode(
videos=videos,
audios=audios)

req_outputs = self.model.encode(inputs)
req_outputs = self.model.embed(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]

def score(
@@ -918,7 +914,7 @@
text_2: Union[str, List[str]],
) -> List[List[float]]:
req_outputs = self.model.score(text_1, text_2)
return [req_output.outputs.embedding for req_output in req_outputs]
return [req_output.outputs.probs for req_output in req_outputs]

def __enter__(self):
return self
5 changes: 2 additions & 3 deletions tests/models/test_oot_registration.py
@@ -2,7 +2,7 @@

import pytest

from vllm import LLM, PoolingParams, SamplingParams
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import fork_new_process_for_each_test
@@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path):
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = PoolingParams()
llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
outputs = llm.encode(prompts, sampling_params)
outputs = llm.embed(prompts)

for output in outputs:
assert all(v == 0 for v in output.outputs.embedding)
31 changes: 7 additions & 24 deletions vllm/__init__.py
@@ -7,7 +7,9 @@
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, PoolingOutput,
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@@ -27,33 +29,14 @@
"CompletionOutput",
"PoolingOutput",
"PoolingRequestOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"ClassificationOutput",
"ClassificationRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
]


def __getattr__(name: str):
import warnings

if name == "EmbeddingOutput":
msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return PoolingOutput

if name == "EmbeddingRequestOutput":
msg = ("EmbeddingRequestOutput has been renamed to "
"PoolingRequestOutput. "
"The original name will be removed in an upcoming version.")

warnings.warn(DeprecationWarning(msg), stacklevel=2)

return PoolingRequestOutput

raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
17 changes: 8 additions & 9 deletions vllm/engine/llm_engine.py
@@ -46,11 +46,10 @@
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
ParallelSampleSequenceGroup, Sequence,
SequenceGroup, SequenceGroupBase,
SequenceGroupMetadata, SequenceGroupOutput,
SequenceStatus)
from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
PoolingSequenceGroupOutput, Sequence, SequenceGroup,
SequenceGroupBase, SequenceGroupMetadata,
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -966,9 +965,9 @@ def has_unfinished_requests_for_virtual_engine(
@staticmethod
def _process_sequence_group_outputs(
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
outputs: List[PoolingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings
seq_group.pooled_data = outputs[0].data

for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_STOPPED
@@ -1784,8 +1783,8 @@ def _get_stats(self,
num_prompt_tokens_iter)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
if model_output and (model_output[0].spec_decode_worker_metrics
is not None):
if model_output and isinstance(model_output[0], SamplerOutput) and (
model_output[0].spec_decode_worker_metrics is not None):
Comment on lines +1786 to +1787 by @DarkLight1337 (Member, Author), Dec 12, 2024:

Since spec decode isn't applicable to pooling models, I have removed spec_decode_worker_metrics from PoolerOutput. The type annotation that model_output is a list of SamplerOutputs is actually incorrect here (it can be a list of PoolerOutput), but I haven't bothered to fix it since we will probably rework this in V1 anyway.

spec_decode_metrics = model_output[0].spec_decode_worker_metrics
else:
spec_decode_metrics = None
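As a minimal, self-contained illustration of the isinstance-based narrowing discussed in the comment above, the sketch below uses simplified SamplerOut/PoolerOut stand-ins (not vLLM's real SamplerOutput/PoolerOutput classes): only sampler outputs can carry spec decode metrics, so the attribute is accessed only after the type check.

# Illustrative sketch of isinstance-based narrowing; the classes here are
# simplified stand-ins, not vLLM's actual output types.
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class SamplerOut:
    spec_decode_worker_metrics: Optional[str] = None


@dataclass
class PoolerOut:
    data: Optional[list] = None  # pooled hidden states; no spec decode fields


def get_spec_decode_metrics(
        model_output: List[Union[SamplerOut, PoolerOut]]) -> Optional[str]:
    # Only sampler outputs can carry spec decode metrics; pooling models
    # never produce them, so narrow the type before touching the attribute.
    if model_output and isinstance(model_output[0], SamplerOut) and (
            model_output[0].spec_decode_worker_metrics is not None):
        return model_output[0].spec_decode_worker_metrics
    return None


print(get_spec_decode_metrics([SamplerOut("draft-acceptance-rate: 0.8")]))
print(get_spec_decode_metrics([PoolerOut([0.1, 0.9])]))  # None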