diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 64c04ea..88a66cb 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -15,6 +15,7 @@ from grpc_reflection.v1alpha import reflection from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.entrypoints.openai.fim import get_fim_encoder_lookup from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.inputs import LLMInputs from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -209,6 +210,8 @@ def __init__( ) self.health_servicer = health_servicer + self.get_fim_encoder = get_fim_encoder_lookup(args.fim) + async def post_init(self) -> None: self.config = await self.engine.get_model_config() @@ -254,7 +257,12 @@ async def Generate( for i, req in enumerate(request.requests): input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize( - sampling_params, truncate_input_tokens, req.text, tokenizer, context + sampling_params, + truncate_input_tokens, + req.text, + req.suffix, + tokenizer, + context, ) inputs = LLMInputs( @@ -352,6 +360,7 @@ async def GenerateStream( # noqa: PLR0915, C901 sampling_params, truncate_input_tokens, request.request.text, + request.request.suffix, tokenizer, context, ) @@ -782,30 +791,47 @@ def _convert_tokens( # noqa: PLR0913 ) token_infos.append(token_info) - async def _validate_prompt_and_tokenize( + async def _validate_prompt_and_tokenize( # noqa: PLR0913 self, sampling_params: SamplingParams, truncate_input_tokens: int | None, prompt: str, + suffix: str | None, tokenizer: AnyTokenizer, context: ServicerContext, ) -> tuple[list[int], bool]: assert self.config is not None - max_model_len = self.config.max_model_len - - tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": ADD_SPECIAL_TOKENS} - if truncate_input_tokens is not None: - tokenizer_kwargs.update( - { - "truncation": True, - "max_length": truncate_input_tokens, - } - ) + if suffix: + if not (get_fim_encoder := self.get_fim_encoder): + await context.abort( + StatusCode.INVALID_ARGUMENT, + "fim support must be enabled to use suffix", + ) + if truncate_input_tokens is not None: + await context.abort( + StatusCode.INVALID_ARGUMENT, + "truncate_input_tokens cannot be used with suffix", + ) + fim_encoder = get_fim_encoder(tokenizer) + input_ids = fim_encoder.encode_with_suffix(prefix=prompt, suffix=suffix) + else: + tokenizer_kwargs: dict[str, Any] = { + "add_special_tokens": ADD_SPECIAL_TOKENS + } + if truncate_input_tokens is not None: + tokenizer_kwargs.update( + { + "truncation": True, + "max_length": truncate_input_tokens, + } + ) - input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids + input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids token_num = len(input_ids) + max_model_len = self.config.max_model_len + try: validate_input(sampling_params, token_num, max_model_len) except ValueError as tgis_validation_error: diff --git a/src/vllm_tgis_adapter/grpc/pb/generation.proto b/src/vllm_tgis_adapter/grpc/pb/generation.proto index dac9237..000f9d7 100644 --- a/src/vllm_tgis_adapter/grpc/pb/generation.proto +++ b/src/vllm_tgis_adapter/grpc/pb/generation.proto @@ -51,6 +51,8 @@ message BatchedGenerationResponse { message GenerationRequest { string text = 2; + /// Optional, for fill-in-middle + string suffix = 3; } message GenerationResponse {