[Frontend] Rerank API (Jina- and Cohere-compatible API) (vllm-project#12376)

Signed-off-by: Kyle Mistele <[email protected]>
K-Mistele authored Jan 27, 2025
1 parent 72bac73 commit 0034b09
Showing 9 changed files with 552 additions and 11 deletions.
92 changes: 92 additions & 0 deletions docs/source/serving/openai_compatible_server.md
@@ -50,6 +50,11 @@ In addition, we have the following custom APIs:
- Applicable to all [pooling models](../models/pooling_models.md).
- [Score API](#score-api) (`/score`)
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
- Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
- Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
- Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response.
- Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).

(chat-template)=

@@ -473,3 +478,90 @@ The following extra parameters are supported:
:start-after: begin-score-extra-params
:end-before: end-score-extra-params
```

(rerank-api)=

### Re-rank API

Our Re-rank API applies a cross-encoder model to predict relevance scores between a single query and
each document in a list. Usually, the score for a query-document pair reflects the similarity between the two texts,
on a scale of 0 to 1.

You can find the documentation for these kinds of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
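
For a quick offline look at what a cross-encoder computes, here is a minimal sketch using the `sentence-transformers` package (an assumption for illustration; it is not required to run the vLLM server):

```python
# Score (query, document) pairs locally with a cross-encoder; this mirrors
# the relevance scores the /rerank endpoint computes server-side.
from sentence_transformers import CrossEncoder

model = CrossEncoder("BAAI/bge-reranker-base")
query = "What is the capital of France?"
documents = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]
scores = model.predict([(query, doc) for doc in documents])
print(scores)  # the higher the score, the more relevant the document
```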

The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base`, as well as other models that
support the `score` task. Additionally, the `/rerank`, `/v1/rerank`, and `/v2/rerank`
endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank), ensuring compatibility with
popular open-source tools.

Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py>

#### Example Request

Note that the `top_n` request parameter is optional and defaults to the length of the `documents` field.
Result documents are sorted by relevance score in descending order, and the `index` property can be used to recover the original order.

Request:

```bash
curl -X 'POST' \
'http://127.0.0.1:8000/v1/rerank' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "BAAI/bge-reranker-base",
"query": "What is the capital of France?",
"documents": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
"Horses and cows are both animals"
]
}'
```

Response:

```json
{
"id": "rerank-fae51b2b664d4ed38f5969b612edff77",
"model": "BAAI/bge-reranker-base",
"usage": {
"total_tokens": 56
},
"results": [
{
"index": 1,
"document": {
"text": "The capital of France is Paris."
},
"relevance_score": 0.99853515625
},
{
"index": 0,
"document": {
"text": "The capital of Brazil is Brasilia."
},
"relevance_score": 0.0005860328674316406
}
]
}
```
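
To return only the single best match, the optional `top_n` parameter described above can be added; a minimal Python sketch of the same request, assuming the server from the example is running:

```python
import requests

# Same request as above, but trimmed to the single best match via the
# optional top_n parameter.
response = requests.post(
    "http://127.0.0.1:8000/v1/rerank",
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
            "Horses and cows are both animals",
        ],
        "top_n": 1,
    },
)
print(response.json()["results"])  # contains only the top-scoring document
```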

#### Extra parameters

The following [pooling parameters](#pooling-params) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-rerank-pooling-params
:end-before: end-rerank-pooling-params
```

The following extra parameters are supported:

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-rerank-extra-params
:end-before: end-rerank-extra-params
```
32 changes: 32 additions & 0 deletions examples/online_serving/cohere_rerank_client.py
@@ -0,0 +1,32 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
run: vllm serve BAAI/bge-reranker-base
"""
import cohere

# cohere v1 client
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = co.rerank(
model="BAAI/bge-reranker-base",
query="What is the capital of France?",
documents=[
"The capital of France is Paris", "Reranking is fun!",
"vLLM is an open-source framework for fast AI serving"
])

print(rerank_v1_result)

# or the v2 client
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")

v2_rerank_result = co2.rerank(
model="BAAI/bge-reranker-base",
query="What is the capital of France?",
documents=[
"The capital of France is Paris", "Reranking is fun!",
"vLLM is an open-source framework for fast AI serving"
])

print(v2_rerank_result)
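
A short follow-on sketch, assuming the installed Cohere SDK exposes rerank results as objects with `index` and `relevance_score` fields (the field names are an assumption; check your SDK version):

```python
# Iterate over the reranked results; each item is assumed to carry the
# original document index and a relevance score.
for item in v2_rerank_result.results:
    print(item.index, item.relevance_score)
```
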
33 changes: 33 additions & 0 deletions examples/online_serving/jinaai_rerank_client.py
@@ -0,0 +1,33 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker
run: vllm serve BAAI/bge-reranker-base
"""
import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}
response = requests.post(url, headers=headers, json=data)

# Check the response
if response.status_code == 200:
print("Request successful!")
print(json.dumps(response.json(), indent=2))
else:
print(f"Request failed with status code: {response.status_code}")
print(response.text)
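
Because the endpoint returns results sorted by relevance (see the response schema in the docs above), the best match can be pulled straight from the JSON; a minimal sketch that reuses the `data` and `response` variables from this example:

```python
# "results" is sorted by relevance_score, highest first; "index" points
# back into the original "documents" list sent in the request.
best = response.json()["results"][0]
print(data["documents"][best["index"]], "->", best["relevance_score"])
```
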
87 changes: 87 additions & 0 deletions tests/entrypoints/openai/test_rerank.py
@@ -0,0 +1,87 @@
import pytest
import requests

from vllm.entrypoints.openai.protocol import RerankResponse

from ...utils import RemoteOpenAIServer

MODEL_NAME = "BAAI/bge-reranker-base"


@pytest.fixture(scope="module")
def server():
args = ["--enforce-eager", "--max-model-len", "100"]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]

rerank_response = requests.post(server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents,
})
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())

assert rerank.id is not None
assert rerank.results is not None
assert len(rerank.results) == 2
assert rerank.results[0].relevance_score >= 0.9
assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.", "Cross-encoder models are neat"
]

rerank_response = requests.post(server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents,
"top_n": 2
})
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())

assert rerank.id is not None
assert rerank.results is not None
assert len(rerank.results) == 2
assert rerank.results[0].relevance_score >= 0.9
assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):

query = "What is the capital of France?" * 100
documents = [
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
]

rerank_response = requests.post(server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents
})
assert rerank_response.status_code == 400
# Assert just a small fragment of the response
assert "Please reduce the length of the input." in \
rerank_response.text
7 changes: 1 addition & 6 deletions tests/entrypoints/openai/test_score.py
@@ -10,12 +10,7 @@

@pytest.fixture(scope="module")
def server():
args = [
"--enforce-eager",
# Will be used on tests to compare prompt input length
"--max-model-len",
"100"
]
args = ["--enforce-eager", "--max-model-len", "100"]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
51 changes: 50 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -56,6 +56,7 @@
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
@@ -68,6 +69,7 @@
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
return request.app.state.openai_serving_scores


def rerank(request: Request) -> Optional[JinaAIServingRerank]:
return request.app.state.jinaai_serving_reranking


def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization

@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)


@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Rerank (Score) API")
generator = await handler.do_rerank(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)


@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning(
"To indicate that the rerank API is not part of the standard OpenAI"
" API, we have located it at `/rerank`. Please update your client"
"accordingly. (Note: Conforms to JinaAI rerank API)")

return await do_rerank(request, raw_request)


@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)


TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
@@ -512,7 +552,10 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"default": (ScoreRequest, create_score),
"default": (RerankRequest, do_rerank)
},
"rerank": {
"default": (RerankRequest, do_rerank)
},
"reward": {
"messages": (PoolingChatRequest, create_pooling),
@@ -759,6 +802,12 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.jinaai_serving_reranking = JinaAIServingRerank(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,