
[Frontend] Add tokenize/detokenize endpoints (vllm-project#5054)
sasha0552 authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent 1d1929b commit 5095252
Showing 5 changed files with 143 additions and 6 deletions.
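
For context, the two endpoints added here can be exercised as follows once a server is running. This is an illustrative sketch, not part of the commit: the server address and model name are placeholders.

import requests

BASE_URL = "http://localhost:8000"  # assumed local vLLM OpenAI-compatible server
MODEL = "my-model"                  # placeholder for a served model name

# Tokenize a prompt into token IDs.
tok = requests.post(f"{BASE_URL}/tokenize",
                    json={"model": MODEL,
                          "prompt": "This is a test prompt.",
                          "add_special_tokens": False})
tok.raise_for_status()
print(tok.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}

# Round-trip the token IDs back into text.
detok = requests.post(f"{BASE_URL}/detokenize",
                      json={"model": MODEL, "tokens": tok.json()["tokens"]})
detok.raise_for_status()
print(detok.json())  # {"prompt": "This is a test prompt."}
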
49 changes: 49 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -9,6 +9,7 @@
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
@@ -1374,5 +1375,53 @@ async def test_long_seed(client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_tokenize(server, client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3]
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

for add_special in [False, True]:
prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

response = requests.post(base_url + "tokenize",
json={
"add_special_tokens": add_special,
"model": model_name,
"prompt": prompt
})
response.raise_for_status()
assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_detokenize(server, client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3]
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=False)

response = requests.post(base_url + "detokenize",
json={
"model": model_name,
"tokens": tokens
})
response.raise_for_status()
assert response.json() == {"prompt": prompt}


if __name__ == "__main__":
pytest.main([__file__])
31 changes: 30 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -19,10 +19,17 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingRequest, ErrorResponse)
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -85,6 +92,28 @@ async def health() -> Response:
return Response(status_code=200)


@app.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_completion.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
return JSONResponse(content=generator.model_dump())


@app.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_completion.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
return JSONResponse(content=generator.model_dump())


@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
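
On the unhappy path, both handlers return the ErrorResponse produced by the model check, serialized as JSON with the status code carried on the response object. A hedged sketch of what a client would observe (server address and exact error body are assumptions, not taken from this diff):

import requests

# Deliberately request a model name that is not being served.
resp = requests.post("http://localhost:8000/tokenize",
                     json={"model": "no-such-model", "prompt": "hi"})
# The handler serializes the ErrorResponse and uses its own status code,
# so this prints a non-200 status together with the error payload.
print(resp.status_code, resp.json())
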
21 changes: 21 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -699,3 +699,24 @@ class BatchRequestOutput(OpenAIBaseModel):
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]


class TokenizeRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)


class TokenizeResponse(OpenAIBaseModel):
tokens: List[int]
count: int
max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
model: str
tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
prompt: str
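
To make the wire format concrete, the request and response bodies implied by these models look roughly like the following (a sketch; the model name and token IDs are illustrative):

# /tokenize
tokenize_request = {"model": "my-model", "prompt": "Hello world", "add_special_tokens": True}
tokenize_response = {"tokens": [9906, 1917], "count": 2, "max_model_len": 8192}

# /detokenize
detokenize_request = {"model": "my-model", "tokens": [9906, 1917]}
detokenize_response = {"prompt": "Hello world"}
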
32 changes: 31 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -16,7 +16,11 @@
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
UsageInfo)
DetokenizeRequest,
DetokenizeResponse,
TokenizeRequest,
TokenizeResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
@@ -442,3 +446,29 @@ def _create_completion_logprobs(
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)

async def create_tokenize(self,
request: TokenizeRequest) -> TokenizeResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret

(input_ids, input_text) = self._validate_prompt_and_tokenize(
request,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)

return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
max_model_len=self.max_model_len)

async def create_detokenize(
self, request: DetokenizeRequest) -> DetokenizeResponse:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret

(input_ids, input_text) = self._validate_prompt_and_tokenize(
request, prompt_ids=request.tokens)

return DetokenizeResponse(prompt=input_text)
16 changes: 12 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -10,9 +10,10 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission)
ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
@@ -99,8 +100,9 @@ def create_streaming_error_response(
return json_str

async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
self, request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
@@ -126,7 +128,8 @@ def _maybe_get_lora(
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
EmbeddingRequest],
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
@@ -174,6 +177,11 @@ def _validate_prompt_and_tokenize(
f"generation. Please reduce the length of the input.", )
return input_ids, input_text

# Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
# and don't require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
return input_ids, input_text

if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
