[Frontend] Separate pooling APIs in offline inference (vllm-project#11129)

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Dec 13, 2024
1 parent f93bf2b commit eeec9e3
Showing 21 changed files with 659 additions and 294 deletions.
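This commit splits the catch-all LLM.encode entry point into task-specific pooling APIs for offline inference: LLM.embed, LLM.classify and LLM.score, with LLM.encode kept for raw hidden states (e.g. reward models). A rough sketch of how the separated APIs are called, pieced together from the examples and docs changed below; loading three models in one script is purely illustrative:

from vllm import LLM

# Embedding model: one embedding vector per prompt.
embed_llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True)
embed_out, = embed_llm.embed("Hello, my name is")
print(len(embed_out.outputs.embedding))

# Classification model: one probability vector per prompt.
cls_llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True)
cls_out, = cls_llm.classify("Hello, my name is")
print(cls_out.outputs.probs)

# Cross-encoder model: one similarity score per (text_1, text_2) pair.
score_llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score", enforce_eager=True)
score_out, = score_llm.score("What is the capital of France?",
                             "The capital of Brazil is Brasilia.")
print(score_out.outputs.score)
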
7 changes: 5 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -181,14 +181,14 @@ steps:
  commands:
  - VLLM_USE_V1=1 pytest -v -s v1

-- label: Examples Test # 15min
+- label: Examples Test # 25min
  working_dir: "/vllm-workspace/examples"
  #mirror_hardwares: [amd]
  source_file_dependencies:
  - vllm/entrypoints
  - examples/
  commands:
-  - pip install awscli tensorizer # for llava example and tensorizer test
+  - pip install tensorizer # for tensorizer test
  - python3 offline_inference.py
  - python3 cpu_offload.py
  - python3 offline_inference_chat.py
@@ -198,6 +198,9 @@ steps:
  - python3 offline_inference_vision_language_multi_image.py
  - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
  - python3 offline_inference_encoder_decoder.py
+  - python3 offline_inference_classification.py
+  - python3 offline_inference_embedding.py
+  - python3 offline_inference_scoring.py
  - python3 offline_profile.py --model facebook/opt-125m

- label: Prefix Caching Test # 9min
53 changes: 45 additions & 8 deletions docs/source/models/pooling_models.rst
@@ -6,7 +6,7 @@ Pooling Models
vLLM also supports pooling models, including embedding, reranking and reward models.

In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface.
-These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input
+These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input
before returning them.

.. note::
@@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults.
``LLM.encode``
^^^^^^^^^^^^^^

The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM.
-It returns the aggregated hidden states directly.
+It returns the extracted hidden states directly, which is useful for reward models.

.. code-block:: python

    llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward")
    output, = llm.encode("Hello, my name is")
    data = output.outputs.data
    print(f"Data: {data!r}")

``LLM.embed``
^^^^^^^^^^^^^

The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt.
It is primarily designed for embedding models.

.. code-block:: python

    llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
-   outputs = llm.encode("Hello, my name is")
+   output, = llm.embed("Hello, my name is")
-   outputs = model.encode(prompts)
-   for output in outputs:
-       embeddings = output.outputs.embedding
-       print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}")
+   embeds = output.outputs.embedding
+   print(f"Embeddings: {embeds!r} (size={len(embeds)})")

A code example can be found in `examples/offline_inference_embedding.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_embedding.py>`_.

``LLM.classify``
^^^^^^^^^^^^^^^^

The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt.
It is primarily designed for classification models.

.. code-block:: python

    llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
    output, = llm.classify("Hello, my name is")
    probs = output.outputs.probs
    print(f"Class Probabilities: {probs!r} (size={len(probs)})")

A code example can be found in `examples/offline_inference_classification.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_classification.py>`_.

``LLM.score``
^^^^^^^^^^^^^

@@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
To handle RAG at a higher level, you should use integration frameworks such as `LangChain <https://github.com/langchain-ai/langchain>`_.

You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/models/embedding/language/test_scoring.py>`_ as reference.

.. code-block:: python

    llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
    output, = llm.score("What is the capital of France?",
                        "The capital of Brazil is Brasilia.")
    score = output.outputs.score
    print(f"Score: {score}")

A code example can be found in `examples/offline_inference_scoring.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_scoring.py>`_.

Online Inference
----------------
28 changes: 28 additions & 0 deletions examples/offline_inference_classification.py
@@ -0,0 +1,28 @@
from vllm import LLM

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create an LLM.
# You should pass task="classify" for classification models
model = LLM(
    model="jason9693/Qwen2.5-1.5B-apeach",
    task="classify",
    enforce_eager=True,
)

# Generate logits. The output is a list of ClassificationRequestOutputs.
outputs = model.classify(prompts)

# Print the outputs.
for prompt, output in zip(prompts, outputs):
    probs = output.outputs.probs
    probs_trimmed = ((str(probs[:16])[:-1] +
                      ", ...]") if len(probs) > 16 else probs)
    print(f"Prompt: {prompt!r} | "
          f"Class Probabilities: {probs_trimmed} (size={len(probs)})")
16 changes: 11 additions & 5 deletions examples/offline_inference_embedding.py
@@ -9,14 +9,20 @@
]

# Create an LLM.
+# You should pass task="embed" for embedding models
model = LLM(
    model="intfloat/e5-mistral-7b-instruct",
-   task="embed",  # You should pass task="embed" for embedding models
+   task="embed",
    enforce_eager=True,
)

-# Generate embedding. The output is a list of PoolingRequestOutputs.
-outputs = model.encode(prompts)
+# Generate embedding. The output is a list of EmbeddingRequestOutputs.
+outputs = model.embed(prompts)

# Print the outputs.
-for output in outputs:
-    print(output.outputs.embedding)  # list of 4096 floats
+for prompt, output in zip(prompts, outputs):
+    embeds = output.outputs.embedding
+    embeds_trimmed = ((str(embeds[:16])[:-1] +
+                       ", ...]") if len(embeds) > 16 else embeds)
+    print(f"Prompt: {prompt!r} | "
+          f"Embeddings: {embeds_trimmed} (size={len(embeds)})")
23 changes: 23 additions & 0 deletions examples/offline_inference_scoring.py
@@ -0,0 +1,23 @@
from vllm import LLM

# Sample prompts.
text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.", "The capital of France is Paris."
]

# Create an LLM.
# You should pass task="score" for cross-encoder models
model = LLM(
    model="BAAI/bge-reranker-v2-m3",
    task="score",
    enforce_eager=True,
)

# Generate scores. The output is a list of ScoringRequestOutputs.
outputs = model.score(text_1, texts_2)

# Print the outputs.
for text_2, output in zip(texts_2, outputs):
    score = output.outputs.score
    print(f"Pair: {[text_1, text_2]!r} | Score: {score}")
2 changes: 1 addition & 1 deletion examples/offline_inference_vision_language_embedding.py
@@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality):
    if req_data.image is not None:
        mm_data["image"] = req_data.image

-   outputs = req_data.llm.encode({
+   outputs = req_data.llm.embed({
        "prompt": req_data.prompt,
        "multi_modal_data": mm_data,
    })
18 changes: 7 additions & 11 deletions tests/conftest.py
@@ -719,14 +719,6 @@ def get_inputs(

        return inputs

-    def classify(self, prompts: List[str]) -> List[str]:
-        req_outputs = self.model.encode(prompts)
-        outputs = []
-        for req_output in req_outputs:
-            embedding = req_output.outputs.embedding
-            outputs.append(embedding)
-        return outputs

    def generate(
        self,
        prompts: List[str],
@@ -897,6 +889,10 @@ def generate_beam_search(
            returned_outputs.append((token_ids, texts))
        return returned_outputs

+    def classify(self, prompts: List[str]) -> List[List[float]]:
+        req_outputs = self.model.classify(prompts)
+        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(
        self,
        prompts: List[str],
@@ -909,16 +905,16 @@ def encode(
                                 videos=videos,
                                 audios=audios)

-        req_outputs = self.model.encode(inputs)
+        req_outputs = self.model.embed(inputs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[float]:
        req_outputs = self.model.score(text_1, text_2)
-        return [req_output.outputs.embedding for req_output in req_outputs]
+        return [req_output.outputs.score for req_output in req_outputs]

    def __enter__(self):
        return self
10 changes: 5 additions & 5 deletions tests/entrypoints/openai/test_score.py
@@ -39,8 +39,8 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 2
-    assert score.data[0].score[0] <= 0.01
-    assert score.data[1].score[0] >= 0.9
+    assert score.data[0].score <= 0.01
+    assert score.data[1].score >= 0.9


@pytest.mark.asyncio
@@ -67,8 +67,8 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 2
-    assert score.data[0].score[0] <= 0.01
-    assert score.data[1].score[0] >= 0.9
+    assert score.data[0].score <= 0.01
+    assert score.data[1].score >= 0.9


@pytest.mark.asyncio
@@ -90,4 +90,4 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 1
-    assert score.data[0].score[0] >= 0.9
+    assert score.data[0].score >= 0.9
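The scoring change above also applies to the HTTP server these tests drive. A rough sketch of a raw request against it, assuming an OpenAI-compatible vLLM server is already running on localhost:8000; the /score route and the text_1/text_2 payload fields are inferred from the test names here, not taken verbatim from this diff:

import requests

# Hypothetical direct call mirroring tests/entrypoints/openai/test_score.py.
resp = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?",
        "text_2": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
resp.raise_for_status()
for item in resp.json()["data"]:
    # After this change, "score" is a plain float rather than a one-element list.
    print(item["score"])
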
10 changes: 5 additions & 5 deletions tests/models/embedding/language/test_scoring.py
@@ -42,7 +42,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
    assert len(vllm_outputs) == 1
    assert len(hf_outputs) == 1

-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)


@pytest.mark.parametrize("dtype", ["half"])
@@ -63,8 +63,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
-    assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)


@pytest.mark.parametrize("dtype", ["half"])
@@ -85,5 +85,5 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str):
    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

-    assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01)
-    assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01)
+    assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01)
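These tests compare vLLM's scores against a Hugging Face reference. A sketch of roughly what that reference computation looks like with sentence-transformers, reusing the prompts from examples/offline_inference_scoring.py; the assumption that hf_runner wraps a CrossEncoder is mine, not stated in this diff:

from sentence_transformers import CrossEncoder

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.", "The capital of France is Paris."
]

# Reference scores from the same reranker checkpoint used in the tests above.
ce = CrossEncoder("BAAI/bge-reranker-v2-m3")
hf_scores = ce.predict([(text_1, t) for t in texts_2])

# With this change the vLLM side returns plain floats, so the comparison is
# element-wise: math.isclose(hf_scores[i], vllm_outputs[i], rel_tol=0.01).
print([float(s) for s in hf_scores])
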
5 changes: 2 additions & 3 deletions tests/models/test_oot_registration.py
@@ -2,7 +2,7 @@

import pytest

-from vllm import LLM, PoolingParams, SamplingParams
+from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import fork_new_process_for_each_test
@@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path):
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
    prompts = ["Hello, my name is", "The text does not matter"]
-    sampling_params = PoolingParams()
    llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
-    outputs = llm.encode(prompts, sampling_params)
+    outputs = llm.embed(prompts)

    for output in outputs:
        assert all(v == 0 for v in output.outputs.embedding)
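The test above depends on a plugin called register_dummy_model that registers an out-of-tree embedding model before it is loaded. A minimal sketch of what such a plugin's registration hook might look like; my_plugin and MyGemma2EmbeddingModel are hypothetical names, and only the ModelRegistry calls come from vLLM itself:

from vllm import ModelRegistry


def register():
    # Hypothetical out-of-tree model class; the real test ships its own dummy
    # Gemma2 embedding model and points VLLM_PLUGINS at its registration hook.
    from my_plugin.models import MyGemma2EmbeddingModel

    if "MyGemma2Model" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model("MyGemma2Model", MyGemma2EmbeddingModel)
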
36 changes: 11 additions & 25 deletions vllm/__init__.py
@@ -7,8 +7,11 @@
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (CompletionOutput, PoolingOutput,
-                          PoolingRequestOutput, RequestOutput)
+from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
+                          CompletionOutput, EmbeddingOutput,
+                          EmbeddingRequestOutput, PoolingOutput,
+                          PoolingRequestOutput, RequestOutput, ScoringOutput,
+                          ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

@@ -27,33 +30,16 @@
"CompletionOutput",
"PoolingOutput",
"PoolingRequestOutput",
"EmbeddingOutput",
"EmbeddingRequestOutput",
"ClassificationOutput",
"ClassificationRequestOutput",
"ScoringOutput",
"ScoringRequestOutput",
"LLMEngine",
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_ray_cluster",
"PoolingParams",
]


-def __getattr__(name: str):
-    import warnings
-
-    if name == "EmbeddingOutput":
-        msg = ("EmbeddingOutput has been renamed to PoolingOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingOutput
-
-    if name == "EmbeddingRequestOutput":
-        msg = ("EmbeddingRequestOutput has been renamed to "
-               "PoolingRequestOutput. "
-               "The original name will be removed in an upcoming version.")
-
-        warnings.warn(DeprecationWarning(msg), stacklevel=2)
-
-        return PoolingRequestOutput
-
-    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
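With the deprecation shim above removed, EmbeddingRequestOutput and the other task-specific output classes are ordinary exports again rather than aliases of the pooling classes. A small sketch of using one of them in a type annotation; the annotation style is illustrative, and the return type follows the comment in the new classification example:

from typing import List

from vllm import LLM, ClassificationRequestOutput

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify", enforce_eager=True)

# classify() yields ClassificationRequestOutputs, which are now real classes
# importable from the top-level package.
outputs: List[ClassificationRequestOutput] = llm.classify(["Hello, my name is"])
print(outputs[0].outputs.probs)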