diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 15a068f035494..58b47ee593060 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -9,6 +9,7 @@
 # using Ray for overall ease of process management, parallel requests,
 # and debugging.
 import ray
+import requests
 import torch
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
@@ -1374,5 +1375,53 @@ async def test_long_seed(client: openai.AsyncOpenAI):
                 or "less_than_equal" in exc_info.value.message)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_tokenize(server, client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3]
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+
+    for add_special in [False, True]:
+        prompt = "This is a test prompt."
+        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
+
+        response = requests.post(base_url + "/tokenize",
+                                 json={
+                                     "add_special_tokens": add_special,
+                                     "model": model_name,
+                                     "prompt": prompt
+                                 })
+        response.raise_for_status()
+        assert response.json() == {
+            "tokens": tokens,
+            "count": len(tokens),
+            "max_model_len": 8192
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_detokenize(server, client: openai.AsyncOpenAI, model_name: str):
+    base_url = str(client.base_url)[:-3]
+    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+
+    prompt = "This is a test prompt."
+    tokens = tokenizer.encode(prompt, add_special_tokens=False)
+
+    response = requests.post(base_url + "detokenize",
+                             json={
+                                 "model": model_name,
+                                 "tokens": tokens
+                             })
+    response.raise_for_status()
+    assert response.json() == {"prompt": prompt}
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index ea6275920c79d..a708176c254ec 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -19,10 +19,17 @@
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
-                                              EmbeddingRequest, ErrorResponse)
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              EmbeddingRequest, ErrorResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse)
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -85,6 +92,28 @@ async def health() -> Response:
     return Response(status_code=200)
 
 
+@app.post("/tokenize")
+async def tokenize(request: TokenizeRequest):
+    generator = await openai_serving_completion.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, TokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
+@app.post("/detokenize")
+async def detokenize(request: DetokenizeRequest):
+    generator = await openai_serving_completion.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    else:
+        assert isinstance(generator, DetokenizeResponse)
+        return JSONResponse(content=generator.model_dump())
+
+
 @app.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_chat.show_available_models()
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index b57d79859aec5..7fb1af158531d 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -699,3 +699,24 @@ class BatchRequestOutput(OpenAIBaseModel):
     # For requests that failed with a non-HTTP error, this will contain more
     # information on the cause of the failure.
     error: Optional[Any]
+
+
+class TokenizeRequest(OpenAIBaseModel):
+    model: str
+    prompt: str
+    add_special_tokens: bool = Field(default=True)
+
+
+class TokenizeResponse(OpenAIBaseModel):
+    tokens: List[int]
+    count: int
+    max_model_len: int
+
+
+class DetokenizeRequest(OpenAIBaseModel):
+    model: str
+    tokens: List[int]
+
+
+class DetokenizeResponse(OpenAIBaseModel):
+    prompt: str
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index c775fa6daa739..8741893c92716 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -16,7 +16,11 @@
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              UsageInfo)
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              TokenizeRequest,
+                                              TokenizeResponse, UsageInfo)
+# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -442,3 +446,29 @@ def _create_completion_logprobs(
             tokens=out_tokens,
             top_logprobs=out_top_logprobs,
         )
+
+    async def create_tokenize(self,
+                              request: TokenizeRequest) -> TokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request,
+            prompt=request.prompt,
+            add_special_tokens=request.add_special_tokens)
+
+        return TokenizeResponse(tokens=input_ids,
+                                count=len(input_ids),
+                                max_model_len=self.max_model_len)
+
+    async def create_detokenize(
+            self, request: DetokenizeRequest) -> DetokenizeResponse:
+        error_check_ret = await self._check_model(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        (input_ids, input_text) = self._validate_prompt_and_tokenize(
+            request, prompt_ids=request.tokens)
+
+        return DetokenizeResponse(prompt=input_text)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 6b5a62efc7f20..84e4127725bb7 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -10,9 +10,10 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
+                                              DetokenizeRequest,
                                               EmbeddingRequest, ErrorResponse,
                                               ModelCard, ModelList,
-                                              ModelPermission)
+                                              ModelPermission, TokenizeRequest)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob
@@ -99,8 +100,9 @@ def create_streaming_error_response(
         return json_str
 
     async def _check_model(
-        self, request: Union[CompletionRequest, ChatCompletionRequest,
-                             EmbeddingRequest]
+        self, request: Union[ChatCompletionRequest, CompletionRequest,
+                             DetokenizeRequest, EmbeddingRequest,
+                             TokenizeRequest]
     ) -> Optional[ErrorResponse]:
         if request.model in self.served_model_names:
             return None
@@ -126,7 +128,8 @@ def _maybe_get_lora(
     def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest,
-                           EmbeddingRequest],
+                           DetokenizeRequest, EmbeddingRequest,
+                           TokenizeRequest],
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None,
             truncate_prompt_tokens: Optional[Annotated[int,
@@ -174,6 +177,11 @@ def _validate_prompt_and_tokenize(
                 f"generation. Please reduce the length of the input.", )
             return input_ids, input_text
 
+        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
+        # and do not require model context length validation
+        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
+            return input_ids, input_text
+
         if request.max_tokens is None:
             if token_num >= self.max_model_len:
                 raise ValueError(
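
For reference, a minimal client-side sketch of how the two endpoints added by this patch are called. The host, port, and model name below are placeholders for illustration; the request and response fields mirror TokenizeRequest/TokenizeResponse and DetokenizeRequest/DetokenizeResponse defined above, and the routes are registered at the server root rather than under /v1.

# Usage sketch (assumes a vLLM OpenAI-compatible server is already running
# at the placeholder address below and serving the placeholder model).
import requests

BASE_URL = "http://localhost:8000"  # placeholder host/port
MODEL = "my-served-model"           # placeholder served model name

# POST /tokenize: returns the token ids, their count, and the model's
# maximum context length.
tok = requests.post(f"{BASE_URL}/tokenize",
                    json={
                        "model": MODEL,
                        "prompt": "This is a test prompt.",
                        "add_special_tokens": True,
                    })
tok.raise_for_status()
token_ids = tok.json()["tokens"]
print(tok.json()["count"], tok.json()["max_model_len"])

# POST /detokenize: converts the token ids back into text.
detok = requests.post(f"{BASE_URL}/detokenize",
                      json={"model": MODEL, "tokens": token_ids})
detok.raise_for_status()
print(detok.json()["prompt"])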