forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Frontend] Rerank API (Jina- and Cohere-compatible API) (vllm-project#12376)
Signed-off-by: Kyle Mistele <[email protected]>
- Loading branch information
Showing
9 changed files
with
552 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
""" | ||
Example of using the OpenAI entrypoint's rerank API which is compatible with | ||
the Cohere SDK: https://github.com/cohere-ai/cohere-python | ||
run: vllm serve BAAI/bge-reranker-base | ||
""" | ||
import cohere | ||
|
||
# Both Cohere client generations below send the same rerank request to the
# same local server, so the request arguments are written once and reused.
rerank_kwargs = dict(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris",
        "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving",
    ],
)

# Cohere v1 client
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = co.rerank(**rerank_kwargs)
print(rerank_v1_result)

# ... or the v2 client, which takes the API key as its first positional
# argument
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
v2_rerank_result = co2.rerank(**rerank_kwargs)
print(v2_rerank_result)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
""" | ||
Example of using the OpenAI entrypoint's rerank API which is compatible with | ||
Jina and Cohere https://jina.ai/reranker | ||
run: vllm serve BAAI/bge-reranker-base | ||
""" | ||
import json | ||
|
||
import requests | ||
|
||
url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}

# Request body: the model to use, the query, and the candidate documents
# to be scored against that query.
data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}

response = requests.post(url, headers=headers, json=data)

# Anything other than HTTP 200 is reported as a failure along with the raw
# response body; a 200 response is pretty-printed as JSON.
if response.status_code != 200:
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)
else:
    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import pytest | ||
import requests | ||
|
||
from vllm.entrypoints.openai.protocol import RerankResponse | ||
|
||
from ...utils import RemoteOpenAIServer | ||
|
||
# Model served to (and requested by) every test in this module.
MODEL_NAME = "BAAI/bge-reranker-base"


@pytest.fixture(scope="module")
def server():
    """Start one RemoteOpenAIServer for the module and yield it to the tests.

    The server is entered as a context manager, so it is torn down once the
    last test in the module has finished with the fixture.
    """
    # NOTE(review): the 100-token limit is presumably what lets
    # test_rerank_max_model_len trigger the over-length error cheaply.
    server_args = ["--enforce-eager", "--max-model-len", "100"]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as running_server:
        yield running_server
|
||
|
||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    """Rerank two documents and check their relevance scores.

    The test is a plain synchronous function (it uses ``requests``), so it
    carries no ``asyncio`` marker.

    Args:
        server: Running server fixture; ``url_for`` builds the endpoint URL.
        model_name: Model identifier sent in the request body.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                    })
    # Fail fast on any HTTP error before trying to parse the body.
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    # NOTE(review): these thresholds presumably rely on the server returning
    # results ordered by descending relevance (Paris first) -- confirm.
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01
|
||
|
||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    """Send three documents with ``top_n=2`` and expect exactly two results.

    The test is a plain synchronous function (it uses ``requests``), so it
    carries no ``asyncio`` marker.

    Args:
        server: Running server fixture; ``url_for`` builds the endpoint URL.
        model_name: Model identifier sent in the request body.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.", "Cross-encoder models are neat"
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                        "top_n": 2
                                    })
    # Fail fast on any HTTP error before trying to parse the body.
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    # Three documents were submitted, but top_n limits the response to two.
    assert len(rerank.results) == 2
    # NOTE(review): these thresholds presumably rely on the server returning
    # results ordered by descending relevance -- confirm.
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01
|
||
|
||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
    """Send an over-long query and expect the request to be rejected.

    The test is a plain synchronous function (it uses ``requests``), so it
    carries no ``asyncio`` marker.

    Args:
        server: Running server fixture; ``url_for`` builds the endpoint URL.
        model_name: Model identifier sent in the request body.
    """
    # Repeat the query 100 times so the input is far longer than normal.
    query = "What is the capital of France?" * 100
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents
                                    })
    assert rerank_response.status_code == 400
    # Assert on just a small fragment of the error message.
    assert "Please reduce the length of the input." in \
        rerank_response.text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.