From 8675bc1fbec80695be1e20c09407262b3052a2b9 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 1 Nov 2024 16:13:35 +0800 Subject: [PATCH] [Frontend] Chat-based Embeddings API (#9759) Signed-off-by: Tyler Michael Smith --- docs/requirements-docs.txt | 2 + docs/source/conf.py | 2 +- docs/source/dev/pooling_params.rst | 5 + docs/source/getting_started/quickstart.rst | 8 +- docs/source/index.rst | 1 + docs/source/models/vlm.rst | 54 ++++- .../serving/openai_compatible_server.md | 55 ++++- tests/entrypoints/openai/test_basic.py | 13 +- tests/entrypoints/openai/test_embedding.py | 137 +++++++---- tests/entrypoints/openai/test_metrics.py | 14 +- tests/entrypoints/openai/test_tokenization.py | 32 +-- .../openai/test_vision_embedding.py | 94 ++++++++ vllm/entrypoints/openai/api_server.py | 96 +++++--- vllm/entrypoints/openai/protocol.py | 87 ++++++- vllm/entrypoints/openai/run_batch.py | 34 ++- vllm/entrypoints/openai/serving_chat.py | 222 +++++++----------- vllm/entrypoints/openai/serving_completion.py | 75 +++--- vllm/entrypoints/openai/serving_embedding.py | 87 ++++--- vllm/entrypoints/openai/serving_engine.py | 159 ++++++++++++- .../openai/serving_tokenization.py | 87 +++---- vllm/pooling_params.py | 4 +- 21 files changed, 853 insertions(+), 415 deletions(-) create mode 100644 docs/source/dev/pooling_params.rst create mode 100644 tests/entrypoints/openai/test_vision_embedding.py diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index d58f226136918..e3e35844405ac 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -13,5 +13,7 @@ torch py-cpuinfo transformers mistral_common >= 1.3.4 +aiohttp +starlette openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 8435129e752e1..c7b638473a931 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -96,7 +96,6 @@ def setup(app): # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ - "aiohttp", "compressed_tensors", "cpuinfo", "cv2", @@ -143,6 +142,7 @@ def add_line(self, line: str, source: str, *lineno: int) -> None: "python": ("https://docs.python.org/3", None), "typing_extensions": ("https://typing-extensions.readthedocs.io/en/latest", None), + "aiohttp": ("https://docs.aiohttp.org/en/stable", None), "pillow": ("https://pillow.readthedocs.io/en/stable", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), diff --git a/docs/source/dev/pooling_params.rst b/docs/source/dev/pooling_params.rst new file mode 100644 index 0000000000000..334e0287aff09 --- /dev/null +++ b/docs/source/dev/pooling_params.rst @@ -0,0 +1,5 @@ +Pooling Parameters +================== + +.. autoclass:: vllm.PoolingParams + :members: diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index f0e6cddf09ef7..00b762ccc2ccb 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -138,10 +138,10 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep A more detailed client example can be found `here `__. -OpenAI Chat API with vLLM -~~~~~~~~~~~~~~~~~~~~~~~~~~ +OpenAI Chat Completions API with vLLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -vLLM is designed to also support the OpenAI Chat API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations. +vLLM is designed to also support the OpenAI Chat Completions API. The chat interface is a more dynamic, interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations. You can use the `create chat completion `_ endpoint to interact with the model: @@ -157,7 +157,7 @@ You can use the `create chat completion `_ API, + Since OpenAI Vision API is based on `Chat Completions API `_, a chat template is **required** to launch the API server. Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it. @@ -243,6 +243,10 @@ To consume the server, you can use the OpenAI client like in the example below: A full code example can be found in `examples/openai_api_client_for_multimodal.py `_. +.. tip:: + There is no need to place image placeholders in the text content of the API request - they are already represented by the image content. + In fact, you can place image placeholders in the middle of the text by interleaving text and image content. + .. note:: By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable: @@ -251,5 +255,49 @@ A full code example can be found in `examples/openai_api_client_for_multimodal.p $ export VLLM_IMAGE_FETCH_TIMEOUT= -.. note:: - There is no need to format the prompt in the API request since it will be handled by the server. +Chat Embeddings API +^^^^^^^^^^^^^^^^^^^ + +vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API `_, +where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models. + +.. tip:: + The schema of ``messages`` is exactly the same as in Chat Completions API. + +In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model. + +.. code-block:: bash + + vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \ + --trust-remote-code --max-model-len 4096 + +.. important:: + + Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding`` + to run this model in embedding mode instead of text generation mode. + +Since this schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library: + +.. code-block:: python + + import requests + + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + + response = requests.post( + "http://localhost:8000/v1/embeddings", + json={ + "model": "TIGER-Lab/VLM2Vec-Full", + "messages": [{ + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Represent the given image."}, + ], + }], + "encoding_format": "float", + }, + ) + response.raise_for_status() + response_json = response.json() + print("Embedding output:", response_json["data"][0]["embedding"]) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a1f93a9a28578..0b5f75caf2475 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -26,13 +26,26 @@ print(completion.choices[0].message) ``` ## API Reference -Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-reference) for more information on the API. We support all parameters except: -- Chat: `tools`, and `tool_choice`. -- Completions: `suffix`. -vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst). +We currently support the following OpenAI APIs: + +- [Completions API](https://platform.openai.com/docs/api-reference/completions) + - *Note: `suffix` parameter is not supported.* +- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) + - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst). + - *Note: `image_url.detail` parameter is not supported.* + - We also support `audio_url` content type for audio files. + - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema. + - *TODO: Support `input_audio` content type as defined [here](https://github.com/openai/openai-python/blob/v1.52.2/src/openai/types/chat/chat_completion_content_part_input_audio_param.py).* + - *Note: `parallel_tool_calls` and `user` parameters are ignored.* +- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings) + - Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API), + which will be treated as a single prompt to the model according to its chat template. + - This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst). + - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.* ## Extra Parameters + vLLM supports a set of parameters that are not part of the OpenAI API. In order to use them, you can pass them as extra parameters in the OpenAI client. Or directly merge them into the JSON payload if you are using HTTP call directly. @@ -49,7 +62,26 @@ completion = client.chat.completions.create( ) ``` -### Extra Parameters for Chat API +### Extra Parameters for Completions API + +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. + +```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-completion-sampling-params +:end-before: end-completion-sampling-params +``` + +The following extra parameters are supported: + +```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-completion-extra-params +:end-before: end-completion-extra-params +``` + +### Extra Parameters for Chat Completions API + The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py @@ -66,21 +98,22 @@ The following extra parameters are supported: :end-before: end-chat-completion-extra-params ``` -### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. +### Extra Parameters for Embeddings API + +The following [pooling parameters (click through to see documentation)](../dev/pooling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python -:start-after: begin-completion-sampling-params -:end-before: end-completion-sampling-params +:start-after: begin-embedding-pooling-params +:end-before: end-embedding-pooling-params ``` The following extra parameters are supported: ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python -:start-after: begin-completion-extra-params -:end-before: end-completion-extra-params +:start-after: begin-embedding-extra-params +:end-before: end-embedding-extra-params ``` ## Chat Template diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index d3aea533b6db9..4616f363cc04a 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -1,7 +1,6 @@ from http import HTTPStatus from typing import List -import openai import pytest import pytest_asyncio import requests @@ -83,10 +82,8 @@ async def client(server): indirect=True, ) @pytest.mark.asyncio -async def test_show_version(client: openai.AsyncOpenAI): - base_url = str(client.base_url)[:-3].strip("/") - - response = requests.get(base_url + "/version") +async def test_show_version(server: RemoteOpenAIServer): + response = requests.get(server.url_for("version")) response.raise_for_status() assert response.json() == {"version": VLLM_VERSION} @@ -102,9 +99,7 @@ async def test_show_version(client: openai.AsyncOpenAI): indirect=True, ) @pytest.mark.asyncio -async def test_check_health(client: openai.AsyncOpenAI): - base_url = str(client.base_url)[:-3].strip("/") - - response = requests.get(base_url + "/health") +async def test_check_health(server: RemoteOpenAIServer): + response = requests.get(server.url_for("health")) assert response.status_code == HTTPStatus.OK diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index f119c6c1201c9..9f2b77dde2a7f 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -4,14 +4,18 @@ import openai import pytest import pytest_asyncio +import requests + +from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer -EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" +DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 @pytest.fixture(scope="module") -def embedding_server(): +def server(): args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -19,31 +23,29 @@ def embedding_server(): "--enforce-eager", "--max-model-len", "8192", + "--chat-template", + DUMMY_CHAT_TEMPLATE, ] - with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @pytest_asyncio.fixture -async def embedding_client(embedding_server): - async with embedding_server.get_async_client() as async_client: +async def client(server): + async with server.get_async_client() as async_client: yield async_client @pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [EMBEDDING_MODEL_NAME], -) -async def test_single_embedding(embedding_client: openai.AsyncOpenAI, - model_name: str): +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str): input_texts = [ "The chef prepared a delicious meal.", ] # test single embedding - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_texts, encoding_format="float", @@ -57,7 +59,7 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI, # test using token IDs input_tokens = [1, 1, 1, 1, 1] - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_tokens, encoding_format="float", @@ -71,18 +73,14 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [EMBEDDING_MODEL_NAME], -) -async def test_batch_embedding(embedding_client: openai.AsyncOpenAI, - model_name: str): +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str): # test List[str] input_texts = [ "The cat sat on the mat.", "A feline was resting on a rug.", "Stars twinkle brightly in the night sky." ] - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_texts, encoding_format="float", @@ -90,11 +88,14 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI, assert embeddings.id is not None assert len(embeddings.data) == 3 assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 32 + assert embeddings.usage.total_tokens == 32 # test List[List[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], [25, 32, 64, 77]] - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_tokens, encoding_format="float", @@ -108,22 +109,70 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [EMBEDDING_MODEL_NAME], -) -async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_conversation_embedding(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "user", + "content": "The cat sat on the mat.", + }, { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }] + + chat_response = requests.post(server.url_for("v1/embeddings"), + json={ + "model": model_name, + "messages": messages, + "encoding_format": "float", + }) + chat_response.raise_for_status() + chat_embeddings = chat_response.json() + + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + prompt = tokenizer.apply_chat_template( + messages, + chat_template=DUMMY_CHAT_TEMPLATE, + add_generation_prompt=True, + continue_final_message=False, + tokenize=False, + ) + completion_response = await client.embeddings.create( + model=model_name, + input=prompt, + encoding_format="float", + # To be consistent with chat + extra_body={"add_special_tokens": False}, + ) + completion_embeddings = completion_response.model_dump(mode="json") + + assert chat_embeddings.pop("id") is not None + assert completion_embeddings.pop("id") is not None + assert chat_embeddings.pop("created") <= completion_embeddings.pop( + "created") + assert chat_embeddings == completion_embeddings + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_batch_base64_embedding(client: openai.AsyncOpenAI, model_name: str): input_texts = [ "Hello my name is", "The best thing about vLLM is that it supports many different models" ] - responses_float = await embedding_client.embeddings.create( - input=input_texts, model=model_name, encoding_format="float") + responses_float = await client.embeddings.create(input=input_texts, + model=model_name, + encoding_format="float") - responses_base64 = await embedding_client.embeddings.create( - input=input_texts, model=model_name, encoding_format="base64") + responses_base64 = await client.embeddings.create(input=input_texts, + model=model_name, + encoding_format="base64") decoded_responses_base64_data = [] for data in responses_base64.data: @@ -137,8 +186,8 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, 1] # Default response is float32 decoded from base64 by OpenAI Client - responses_default = await embedding_client.embeddings.create( - input=input_texts, model=model_name) + responses_default = await client.embeddings.create(input=input_texts, + model=model_name) assert responses_float.data[0].embedding == responses_default.data[ 0].embedding @@ -147,18 +196,15 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [EMBEDDING_MODEL_NAME], -) -async def test_single_embedding_truncation( - embedding_client: openai.AsyncOpenAI, model_name: str): +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_embedding_truncation(client: openai.AsyncOpenAI, + model_name: str): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] # test single embedding - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10}) @@ -173,7 +219,7 @@ async def test_single_embedding_truncation( 1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728, 9901, 340, 2229, 385, 340, 315, 28741, 28804, 2 ] - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10}) @@ -187,18 +233,15 @@ async def test_single_embedding_truncation( @pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [EMBEDDING_MODEL_NAME], -) -async def test_single_embedding_truncation_invalid( - embedding_client: openai.AsyncOpenAI, model_name: str): +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, + model_name: str): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] with pytest.raises(openai.BadRequestError): - embeddings = await embedding_client.embeddings.create( + embeddings = await client.embeddings.create( model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 8193}) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 6cb74eb78cbf0..b3f1fea91d13e 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -79,9 +79,8 @@ async def client(server): @pytest.mark.asyncio -async def test_metrics_counts(client: openai.AsyncOpenAI): - base_url = str(client.base_url)[:-3].strip("/") - +async def test_metrics_counts(server: RemoteOpenAIServer, + client: openai.AsyncClient): for _ in range(_NUM_REQUESTS): # sending a request triggers the metrics to be logged. await client.completions.create( @@ -89,7 +88,7 @@ async def test_metrics_counts(client: openai.AsyncOpenAI): prompt=_TOKENIZED_PROMPT, max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST) - response = requests.get(base_url + "/metrics") + response = requests.get(server.url_for("metrics")) print(response.text) assert response.status_code == HTTPStatus.OK @@ -170,16 +169,15 @@ async def test_metrics_counts(client: openai.AsyncOpenAI): @pytest.mark.asyncio -async def test_metrics_exist(client: openai.AsyncOpenAI): - base_url = str(client.base_url)[:-3].strip("/") - +async def test_metrics_exist(server: RemoteOpenAIServer, + client: openai.AsyncClient): # sending a request triggers the metrics to be logged. await client.completions.create(model=MODEL_NAME, prompt="Hello, my name is", max_tokens=5, temperature=0.0) - response = requests.get(base_url + "/metrics") + response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK for metric in EXPECTED_METRICS: diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 859a676a9c777..b1956a8cbc9dc 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -1,4 +1,3 @@ -import openai # use the official client for correctness check import pytest import pytest_asyncio import requests @@ -55,9 +54,11 @@ async def client(server): [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], indirect=["tokenizer_name"], ) -async def test_tokenize_completions(client: openai.AsyncOpenAI, - model_name: str, tokenizer_name: str): - base_url = str(client.base_url)[:-3].strip("/") +async def test_tokenize_completions( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") @@ -65,7 +66,7 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI, prompt = "vllm1 This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(base_url + "/tokenize", + response = requests.post(server.url_for("tokenize"), json={ "add_special_tokens": add_special, "model": model_name, @@ -86,9 +87,11 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI, [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], indirect=["tokenizer_name"], ) -async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str, - tokenizer_name: str): - base_url = str(client.base_url)[:-3].strip("/") +async def test_tokenize_chat( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") @@ -121,7 +124,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str, tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(base_url + "/tokenize", + response = requests.post(server.url_for("tokenize"), json={ "add_generation_prompt": add_generation, @@ -146,17 +149,18 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str, [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], indirect=["tokenizer_name"], ) -async def test_detokenize(client: openai.AsyncOpenAI, model_name: str, - tokenizer_name: str): - base_url = str(client.base_url)[:-3].strip("/") +async def test_detokenize( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a test prompt. vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) - print(f"CALLING {base_url} FOR {model_name}") - response = requests.post(base_url + "/detokenize", + response = requests.post(server.url_for("detokenize"), json={ "model": model_name, "tokens": tokens diff --git a/tests/entrypoints/openai/test_vision_embedding.py b/tests/entrypoints/openai/test_vision_embedding.py new file mode 100644 index 0000000000000..73a69da32e434 --- /dev/null +++ b/tests/entrypoints/openai/test_vision_embedding.py @@ -0,0 +1,94 @@ +from typing import Dict + +import pytest +import pytest_asyncio +import requests + +from vllm.multimodal.utils import encode_image_base64, fetch_image + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" +MAXIMUM_IMAGES = 2 + +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "embedding", + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "5", + "--enforce-eager", + "--trust-remote-code", + "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.fixture(scope="session") +def base64_encoded_image() -> Dict[str, str]: + return { + image_url: encode_image_base64(fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, + image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Represent the given image." + }, + ], + }] + + response = requests.post(server.url_for("v1/embeddings"), + json={ + "model": model_name, + "messages": messages, + "encoding_format": "float" + }) + response.raise_for_status() + + embeddings = response.json() + assert embeddings["id"] is not None + assert len(embeddings["data"]) == 1 + assert len(embeddings["data"][0]["embedding"]) == 3072 + assert embeddings["usage"]["completion_tokens"] == 0 + assert embeddings["usage"]["prompt_tokens"] == 771 + assert embeddings["usage"]["total_tokens"] == 771 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 46c92e10b360c..95fd56d916050 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Set +from typing import AsyncIterator, Optional, Set import uvloop from fastapi import APIRouter, FastAPI, Request @@ -51,7 +51,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -248,20 +248,25 @@ def mount_metrics(app: FastAPI): app.routes.append(metrics_route) -def chat(request: Request) -> OpenAIServingChat: +def base(request: Request) -> OpenAIServing: + # Reuse the existing instance + return tokenization(request) + + +def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat -def completion(request: Request) -> OpenAIServingCompletion: +def completion(request: Request) -> Optional[OpenAIServingCompletion]: return request.app.state.openai_serving_completion -def tokenization(request: Request) -> OpenAIServingTokenization: - return request.app.state.openai_serving_tokenization +def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: + return request.app.state.openai_serving_embedding -def embedding(request: Request) -> OpenAIServingEmbedding: - return request.app.state.openai_serving_embedding +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization def engine_client(request: Request) -> EngineClient: @@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response: @router.post("/tokenize") async def tokenize(request: TokenizeRequest, raw_request: Request): - generator = await tokenization(raw_request).create_tokenize(request) + handler = tokenization(raw_request) + + generator = await handler.create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): @router.post("/detokenize") async def detokenize(request: DetokenizeRequest, raw_request: Request): - generator = await tokenization(raw_request).create_detokenize(request) + handler = tokenization(raw_request) + + generator = await handler.create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): @router.get("/v1/models") async def show_available_models(raw_request: Request): - models = await completion(raw_request).show_available_models() + handler = base(raw_request) + + models = await handler.show_available_models() return JSONResponse(content=models.model_dump()) @@ -314,9 +325,12 @@ async def show_version(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + handler = chat(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Chat Completions API") - generator = await chat(raw_request).create_chat_completion( - request, raw_request) + generator = await handler.create_chat_completion(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest, @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): - generator = await completion(raw_request).create_completion( - request, raw_request) + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + generator = await handler.create_completion(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request): @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): - generator = await embedding(raw_request).create_embedding( - request, raw_request) + handler = embedding(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + generator = await handler.create_embedding(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -382,30 +404,26 @@ async def stop_profile(raw_request: Request): @router.post("/v1/load_lora_adapter") async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): - response = await chat(raw_request).load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) - - response = await completion(raw_request).load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + for route in [chat, completion, embedding]: + handler = route(raw_request) + if handler is not None: + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): - response = await chat(raw_request).unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) - - response = await completion(raw_request).unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + for route in [chat, completion, embedding]: + handler = route(raw_request) + if handler is not None: + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @@ -501,7 +519,8 @@ def init_app_state( chat_template=args.chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser) + tool_parser=args.tool_call_parser, + ) if model_config.task == "generate" else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, @@ -510,13 +529,14 @@ def init_app_state( prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) + ) if model_config.task == "generate" else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, base_model_paths, request_logger=request_logger, - ) + chat_template=args.chat_template, + ) if model_config.task == "embedding" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 60fc5ac8d11d2..1335e51bd152c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -708,7 +708,7 @@ def validate_stream_options(cls, data): return data -class EmbeddingRequest(OpenAIBaseModel): +class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings model: str @@ -720,10 +720,15 @@ class EmbeddingRequest(OpenAIBaseModel): # doc: begin-embedding-pooling-params additional_data: Optional[Any] = None - # doc: end-embedding-pooling-params # doc: begin-embedding-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) priority: int = Field( default=0, description=( @@ -737,6 +742,82 @@ def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) +class EmbeddingChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-chat-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-chat-embedding-pooling-params + + # doc: begin-chat-embedding-extra-params + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + # doc: end-chat-embedding-extra-params + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] + + class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -799,7 +880,7 @@ class EmbeddingResponseData(OpenAIBaseModel): class EmbeddingResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) model: str diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index f5249a0c447b3..a64467a311523 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -217,13 +217,14 @@ async def main(args): prompt_adapters=None, request_logger=request_logger, chat_template=None, - ) + ) if model_config.task == "generate" else None openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, base_model_paths, request_logger=request_logger, - ) + chat_template=None, + ) if model_config.task == "embedding" else None tracker = BatchProgressTracker() logger.info("Reading batch from %s...", args.input_file) @@ -240,14 +241,31 @@ async def main(args): # Determine the type of request and run it. if request.url == "/v1/chat/completions": - response_futures.append( - run_request(openai_serving_chat.create_chat_completion, - request, tracker)) + handler_fn = (None if openai_serving_chat is None else + openai_serving_chat.create_chat_completion) + if handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg= + "The model does not support Chat Completions API", + )) + continue + + response_futures.append(run_request(handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/embeddings": - response_futures.append( - run_request(openai_serving_embedding.create_embedding, request, - tracker)) + handler_fn = (None if openai_serving_embedding is None else + openai_serving_embedding.create_embedding) + if handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Embeddings API", + )) + continue + + response_futures.append(run_request(handler_fn, request, tracker)) tracker.submitted() else: response_futures.append( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1f951d15a7a32..9551b4f2091dd 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -10,11 +10,7 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - load_chat_template, - parse_chat_messages_futures) +from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -27,16 +23,12 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, OpenAIServing, - PromptAdapterPath, - TextTokensPrompt) + PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager -from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob -from vllm.tracing import (contains_trace_headers, extract_trace_headers, - log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import iterate_with_cancellation @@ -94,12 +86,12 @@ async def create_chat_completion( raw_request: Optional[Request] = None, ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: - """Completion API similar to OpenAI's API. + """ + Chat Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create for the API specification. This API mimics the OpenAI - ChatCompletion API. - + Chat Completion API. """ error_check_ret = await self._check_model(request) if error_check_ret is not None: @@ -118,143 +110,106 @@ async def create_chat_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - model_config = self.model_config tokenizer = await self.engine_client.get_tokenizer(lora_request) - - conversation, mm_data_future = parse_chat_messages_futures( - request.messages, model_config, tokenizer) + tool_parser = self.tool_parser + + # validation for OpenAI tools + # tool_choice = "required" is not supported + if request.tool_choice == "required": + return self.create_error_response( + "tool_choice = \"required\" is not supported!") + + if (request.tool_choice == "auto" and + not (self.enable_auto_tools and tool_parser is not None) + and not isinstance(tokenizer, MistralTokenizer)): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set" + ) tool_dicts = None if request.tools is None else [ tool.model_dump() for tool in request.tools ] - prompt: Union[str, List[int]] - is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) - if is_mistral_tokenizer: - prompt = apply_mistral_chat_template( - tokenizer, - messages=request.messages, - chat_template=request.chat_template or self.chat_template, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tools=tool_dicts, - documents=request.documents, - **(request.chat_template_kwargs or {}), - ) - else: - prompt = apply_hf_chat_template( - tokenizer, - conversation=conversation, - chat_template=request.chat_template or self.chat_template, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, - tools=tool_dicts, - documents=request.documents, - **(request.chat_template_kwargs or {}), - ) - except Exception as e: - logger.exception("Error in applying chat template from request") - return self.create_error_response(str(e)) - - try: - mm_data = await mm_data_future - except Exception as e: - logger.exception("Error in loading multi-modal data") + ( + conversation, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tool_dicts=tool_dicts, + documents=request.documents, + chat_template_kwargs=request.chat_template_kwargs, + tool_parser=tool_parser, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - # validation for OpenAI tools - # tool_choice = "required" is not supported - if request.tool_choice == "required": - return self.create_error_response( - "tool_choice = \"required\" is not supported!") - - if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( - self.enable_auto_tools and self.tool_parser is not None): - # for hf tokenizers, "auto" tools requires - # --enable-auto-tool-choice and --tool-call-parser - return self.create_error_response( - "\"auto\" tool choice requires " - "--enable-auto-tool-choice and --tool-call-parser to be set") - - request_id = f"chat-{request.request_id}" + request_id = f"chatcmpl-{request.request_id}" request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: raw_request.state.request_metadata = request_metadata + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[RequestOutput, None]] = [] try: - if self.enable_auto_tools and self.tool_parser: - request = self.tool_parser(tokenizer).adjust_request( - request=request) - - if isinstance(prompt, str): - prompt_inputs = self._tokenize_prompt_input( - request, - tokenizer, - prompt, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) - else: - assert isinstance(prompt, list) and isinstance( - prompt[0], int - ), "Prompt has to be either a string or a list of token ids" - prompt_inputs = TextTokensPrompt( - prompt=tokenizer.decode(prompt), prompt_token_ids=prompt) - - assert prompt_inputs is not None - - sampling_params: Union[SamplingParams, BeamSearchParams] - default_max_tokens = self.max_model_len - len( - prompt_inputs["prompt_token_ids"]) - if request.use_beam_search: - sampling_params = request.to_beam_search_params( - default_max_tokens) - else: - sampling_params = request.to_sampling_params( - default_max_tokens) - - self._log_inputs(request_id, - prompt_inputs, - params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) - - engine_inputs = TokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"]) - if mm_data is not None: - engine_inputs["multi_modal_data"] = mm_data - - is_tracing_enabled = (await - self.engine_client.is_tracing_enabled()) - trace_headers = None - if is_tracing_enabled and raw_request: - trace_headers = extract_trace_headers(raw_request.headers) - if (not is_tracing_enabled and raw_request - and contains_trace_headers(raw_request.headers)): - log_tracing_disabled_warning() - - if isinstance(sampling_params, BeamSearchParams): - result_generator = self.engine_client.beam_search( - prompt=engine_inputs, - model_config=self.model_config, - request_id=request_id, - params=sampling_params, - ) - else: - result_generator = self.engine_client.generate( - engine_inputs, - sampling_params, - request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=request.priority, - ) + for i, engine_prompt in enumerate(engine_prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens) + else: + sampling_params = request.to_sampling_params( + default_max_tokens) + + self._log_inputs(request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + if isinstance(sampling_params, BeamSearchParams): + generator = self.engine_client.beam_search( + prompt=engine_prompt, + model_config=self.model_config, + request_id=request_id, + params=sampling_params, + ) + else: + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) + assert len(generators) == 1 + result_generator, = generators + if raw_request: result_generator = iterate_with_cancellation( result_generator, raw_request.is_disconnected) @@ -626,6 +581,9 @@ async def chat_completion_full_generator( final_res = res except asyncio.CancelledError: return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) assert final_res is not None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index da521a6012530..570232be38379 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,7 +1,6 @@ import asyncio import time -from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Optional) +from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple, Union, cast @@ -30,18 +29,11 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob -from vllm.tracing import (contains_trace_headers, extract_trace_headers, - log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) -TypeTokenIDs = List[int] -TypeTopLogProbs = List[Optional[Dict[int, float]]] -TypeCreateLogProbsFn = Callable[ - [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] - class OpenAIServingCompletion(OpenAIServing): @@ -101,8 +93,6 @@ async def create_completion( if raw_request: raw_request.state.request_metadata = request_metadata - # Schedule the request and get the result generator. - generators: List[AsyncGenerator[RequestOutput, None]] = [] try: ( lora_request, @@ -111,19 +101,24 @@ async def create_completion( tokenizer = await self.engine_client.get_tokenizer(lora_request) - prompts = list( - self._tokenize_prompt_input_or_inputs( - request, - tokenizer, - request.prompt, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - )) + request_prompts, engine_prompts = self._preprocess_completion( + request, + tokenizer, + request.prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) - for i, prompt_inputs in enumerate(prompts): + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): sampling_params: Union[SamplingParams, BeamSearchParams] default_max_tokens = self.max_model_len - len( - prompt_inputs["prompt_token_ids"]) + engine_prompt["prompt_token_ids"]) if request.use_beam_search: sampling_params = request.to_beam_search_params( default_max_tokens) @@ -134,36 +129,24 @@ async def create_completion( request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, - prompt_inputs, + request_prompts[i], params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - is_tracing_enabled = (await - self.engine_client.is_tracing_enabled()) - trace_headers = None - if is_tracing_enabled: - trace_headers = extract_trace_headers(raw_request.headers) - if not is_tracing_enabled and contains_trace_headers( - raw_request.headers): - log_tracing_disabled_warning() + trace_headers = (await + self._get_trace_headers(raw_request.headers)) if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( - prompt={ - "prompt_token_ids": - prompt_inputs["prompt_token_ids"] - }, + prompt=engine_prompt, model_config=self.model_config, request_id=request_id, params=sampling_params, ) else: generator = self.engine_client.generate( - { - "prompt_token_ids": - prompt_inputs["prompt_token_ids"] - }, + engine_prompt, sampling_params, request_id_item, lora_request=lora_request, @@ -180,6 +163,8 @@ async def create_completion( result_generator = merge_async_iterators( *generators, is_cancelled=raw_request.is_disconnected) + num_prompts = len(engine_prompts) + # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use # beam search. @@ -195,16 +180,22 @@ async def create_completion( request_id, created_time, model_name, - num_prompts=len(prompts), + num_prompts=num_prompts, tokenizer=tokenizer, request_metadata=request_metadata) # Non-streaming response - final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) + final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts try: async for i, res in result_generator: final_res_batch[i] = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + try: for i, final_res in enumerate(final_res_batch): assert final_res is not None @@ -212,7 +203,7 @@ async def create_completion( # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - final_res.prompt = prompts[i]["prompt"] + final_res.prompt = request_prompts[i]["prompt"] final_res_batch_checked = cast(List[RequestOutput], final_res_batch) @@ -226,8 +217,6 @@ async def create_completion( tokenizer, request_metadata, ) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 6c46aae2838f6..917856cd2b2dd 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -9,8 +9,10 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (EmbeddingRequest, +from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, + EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) @@ -21,8 +23,6 @@ logger = init_logger(__name__) -TypeTokenIDs = List[int] - def _get_embedding( output: EmbeddingOutput, @@ -76,6 +76,7 @@ def __init__( base_model_paths: List[BaseModelPath], *, request_logger: Optional[RequestLogger], + chat_template: Optional[str], ): super().__init__(engine_client=engine_client, model_config=model_config, @@ -83,21 +84,20 @@ def __init__( lora_modules=None, prompt_adapters=None, request_logger=request_logger) - self._enabled = self._check_embedding_mode( - model_config.task == "embedding") + + self.chat_template = load_chat_template(chat_template) async def create_embedding( self, request: EmbeddingRequest, raw_request: Optional[Request] = None, ) -> Union[EmbeddingResponse, ErrorResponse]: - """Completion API similar to OpenAI's API. + """ + Embedding API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/embeddings/create for the API specification. This API mimics the OpenAI Embedding API. """ - if not self._enabled: - return self.create_error_response("Embedding API disabled") error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -122,8 +122,6 @@ async def create_embedding( "greater than max_model_len." " Please, select a smaller truncation size.") - # Schedule the request and get the result generator. - generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] try: ( lora_request, @@ -132,32 +130,60 @@ async def create_embedding( tokenizer = await self.engine_client.get_tokenizer(lora_request) - pooling_params = request.to_pooling_params() + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for embedding models") + + if isinstance(request, EmbeddingChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, + tokenizer, + request.messages, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + request_prompts, engine_prompts = self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) - prompts = list( - self._tokenize_prompt_input_or_inputs(request, tokenizer, - request.input, - truncate_prompt_tokens)) + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + try: + pooling_params = request.to_pooling_params() - for i, prompt_inputs in enumerate(prompts): + for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, - prompt_inputs, + request_prompts[i], params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - if prompt_adapter_request is not None: - raise NotImplementedError( - "Prompt adapter is not supported " - "for embedding models") + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) generator = self.engine_client.encode( - {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, + engine_prompt, pooling_params, request_id_item, lora_request=lora_request, + trace_headers=trace_headers, priority=request.priority, ) @@ -171,13 +197,18 @@ async def create_embedding( is_cancelled=raw_request.is_disconnected if raw_request else None, ) + num_prompts = len(engine_prompts) + # Non-streaming response final_res_batch: List[Optional[EmbeddingRequestOutput]] - final_res_batch = [None] * len(prompts) + final_res_batch = [None] * num_prompts try: async for i, res in result_generator: final_res_batch[i] = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + try: for final_res in final_res_batch: assert final_res is not None @@ -187,18 +218,8 @@ async def create_embedding( response = request_output_to_embedding_response( final_res_batch_checked, request_id, created_time, model_name, encoding_format) - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) return response - - def _check_embedding_mode(self, embedding_mode: bool) -> bool: - if not embedding_mode: - logger.warning( - "embedding_mode is False. Embedding API will not work.") - else: - logger.info("Activating the server engine with embedding enabled.") - return embedding_mode diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 22a01b3dc4cc0..e7aeac8f8c018 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,28 +2,38 @@ import pathlib from dataclasses import dataclass from http import HTTPStatus -from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union +from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping, + Optional, Sequence, Tuple, TypedDict, Union) from pydantic import Field +from starlette.datastructures import Headers from typing_extensions import Annotated from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, - EmbeddingRequest, ErrorResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + ErrorResponse, LoadLoraAdapterRequest, ModelCard, ModelList, ModelPermission, TokenizeChatRequest, TokenizeCompletionRequest, - TokenizeRequest, UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable +from vllm.inputs import TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -31,8 +41,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import AtomicCounter +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import AtomicCounter, is_list_of logger = init_logger(__name__) @@ -56,8 +68,14 @@ class LoRAModulePath: base_model_name: Optional[str] = None -AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, - EmbeddingRequest, TokenizeRequest] +CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, + EmbeddingCompletionRequest, + TokenizeCompletionRequest] + +ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, + TokenizeChatRequest] + +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest] class TextTokensPrompt(TypedDict): @@ -65,6 +83,9 @@ class TextTokensPrompt(TypedDict): prompt_token_ids: List[int] +RequestPrompt = Union[List[int], str, TextTokensPrompt] + + class OpenAIServing: def __init__( @@ -246,7 +267,8 @@ def _validate_input( token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens - if isinstance(request, EmbeddingRequest): + if isinstance(request, + (EmbeddingChatRequest, EmbeddingCompletionRequest)): if token_num > self.max_model_len: raise ValueError( f"This model's maximum context length is " @@ -373,10 +395,115 @@ def _tokenize_prompt_input_or_inputs( truncate_prompt_tokens=truncate_prompt_tokens, ) + def _preprocess_completion( + self, + request: CompletionLikeRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]: + request_prompts = [ + request_prompt + for request_prompt in self._tokenize_prompt_input_or_inputs( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + ] + + engine_prompts = [ + TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) + for request_prompt in request_prompts + ] + + return request_prompts, engine_prompts + + async def _preprocess_chat( + self, + request: ChatLikeRequest, + tokenizer: AnyTokenizer, + messages: List[ChatCompletionMessageParam], + chat_template: Optional[str] = None, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tool_dicts: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Dict[str, str]]] = None, + chat_template_kwargs: Optional[Dict[str, Any]] = None, + tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = False, + ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt], + List[TokensPrompt]]: + conversation, mm_data_future = parse_chat_messages_futures( + messages, + self.model_config, + tokenizer, + ) + + request_prompt: Union[str, List[int]] + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: + request_prompt = apply_mistral_chat_template( + tokenizer, + messages=messages, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + **(chat_template_kwargs or {}), + ) + else: + request_prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tool_dicts, + documents=documents, + **(chat_template_kwargs or {}), + ) + + mm_data = await mm_data_future + + if tool_parser is not None: + if not isinstance(request, ChatCompletionRequest): + msg = "Tool usage is only supported for Chat Completions API" + raise NotImplementedError(msg) + + request = tool_parser(tokenizer).adjust_request(request=request) + + if isinstance(request_prompt, str): + prompt_inputs = self._tokenize_prompt_input( + request, + tokenizer, + request_prompt, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + # For MistralTokenizer + assert is_list_of(request_prompt, int), ( + "Prompt has to be either a string or a list of token ids") + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(request_prompt), + prompt_token_ids=request_prompt) + + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + + return conversation, [request_prompt], [engine_prompt] + def _log_inputs( self, request_id: str, - inputs: Union[str, List[int], TextTokensPrompt], + inputs: RequestPrompt, params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], @@ -404,6 +531,20 @@ def _log_inputs( prompt_adapter_request=prompt_adapter_request, ) + async def _get_trace_headers( + self, + headers: Headers, + ) -> Optional[Mapping[str, str]]: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + @staticmethod def _get_decoded_token(logprob: Logprob, token_id: int, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index a269c94c7ec0d..1fd82304f7a4d 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -2,10 +2,7 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (apply_hf_chat_template, - apply_mistral_chat_template, - load_chat_template, - parse_chat_messages_futures) +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -20,7 +17,6 @@ LoRAModulePath, OpenAIServing) from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.utils import random_uuid logger = init_logger(__name__) @@ -62,59 +58,51 @@ async def create_tokenize( request_id = f"tokn-{random_uuid()}" - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) - - tokenizer = await self.engine_client.get_tokenizer(lora_request) - - prompt: Union[str, List[int]] - if isinstance(request, TokenizeChatRequest): - model_config = self.model_config - - conversation, mm_data_future = parse_chat_messages_futures( - request.messages, model_config, tokenizer) - - mm_data = await mm_data_future - if mm_data: - logger.warning( - "Multi-modal inputs are ignored during tokenization") - - if isinstance(tokenizer, MistralTokenizer): - prompt = apply_mistral_chat_template( + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if isinstance(request, TokenizeChatRequest): + ( + _, + request_prompts, + engine_prompts, + ) = await self._preprocess_chat( + request, tokenizer, - messages=request.messages, + request.messages, chat_template=self.chat_template, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, + add_special_tokens=request.add_special_tokens, ) else: - prompt = apply_hf_chat_template( + request_prompts, engine_prompts = self._preprocess_completion( + request, tokenizer, - conversation=conversation, - chat_template=self.chat_template, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, + request.prompt, + add_special_tokens=request.add_special_tokens, ) - else: - prompt = request.prompt + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) - self._log_inputs(request_id, - prompt, - params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + input_ids: List[int] = [] + for i, engine_prompt in enumerate(engine_prompts): + self._log_inputs(request_id, + request_prompts[i], + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) - # Silently ignore prompt adapter since it does not affect tokenization + # Silently ignore prompt adapter since it does not affect + # tokenization (Unlike in Embeddings API where an error is raised) - prompt_input = self._tokenize_prompt_input( - request, - tokenizer, - prompt, - add_special_tokens=request.add_special_tokens, - ) - input_ids = prompt_input["prompt_token_ids"] + input_ids.extend(engine_prompt["prompt_token_ids"]) return TokenizeResponse(tokens=input_ids, count=len(input_ids), @@ -143,9 +131,8 @@ async def create_detokenize( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - if prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for tokenization") + # Silently ignore prompt adapter since it does not affect tokenization + # (Unlike in Embeddings API where an error is raised) prompt_input = self._tokenize_prompt_input( request, diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 7461fb51989c6..2635c0bccd1c4 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -7,7 +7,7 @@ class PoolingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] - """Pooling parameters for pooling. + """Pooling parameters for embeddings API. Attributes: additional_data: Any additional data needed for pooling. @@ -16,7 +16,7 @@ class PoolingParams( def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" - return PoolingParams(additional_data=self.additional_data, ) + return PoolingParams(additional_data=self.additional_data) def __repr__(self) -> str: return (f"PoolingParams("