Support suffix for fill-in-the-middle
njhill committed Nov 6, 2024
1 parent 2c80c72 commit b9f6005
Showing 2 changed files with 41 additions and 13 deletions.
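Background on the feature, for readers skimming the diff: fill-in-the-middle (FIM) generation gives the model both a prefix and a suffix and asks it to produce the text in between. The sketch below is an illustration only; the sentinel tokens shown are StarCoder-style and are an assumption, since the actual tokens are model-specific, which is exactly why this adapter delegates encoding to a FIM encoder obtained from vLLM.

# Illustration only: FIM asks the model to generate what belongs *between*
# a prefix and a suffix. The sentinel tokens below are an assumption
# (StarCoder-style); real tokens depend on the model.
prefix = "def add(a, b):\n    "
suffix = "\n    return result\n"
prompt = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
# The model would then complete with something like: "result = a + b"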
52 changes: 39 additions & 13 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -15,6 +15,7 @@
 from grpc_reflection.v1alpha import reflection
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.fim import get_fim_encoder_lookup
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import LLMInputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -209,6 +210,8 @@ def __init__(
         )
         self.health_servicer = health_servicer

+        self.get_fim_encoder = get_fim_encoder_lookup(args.fim)
+
     async def post_init(self) -> None:
         self.config = await self.engine.get_model_config()

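Interface note, inferred from how the lookup is used in this diff rather than from vLLM's actual type definitions: get_fim_encoder_lookup(args.fim) appears to return either None (FIM not configured) or a factory that binds a tokenizer to an encoder exposing encode_with_suffix. A minimal Protocol sketch of that assumed contract:

from typing import Callable, Optional, Protocol


class FIMEncoder(Protocol):
    """Assumed interface: encodes a prefix/suffix pair into token ids."""

    def encode_with_suffix(self, *, prefix: str, suffix: str) -> list[int]: ...


# Assumed return type of get_fim_encoder_lookup(args.fim): None when FIM is not
# enabled, otherwise a factory that binds a tokenizer to a model-specific encoder.
FimEncoderFactory = Callable[[object], FIMEncoder]  # argument is the tokenizer
FimEncoderLookup = Optional[FimEncoderFactory]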
@@ -254,7 +257,12 @@ async def Generate(

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
-                sampling_params, truncate_input_tokens, req.text, tokenizer, context
+                sampling_params,
+                truncate_input_tokens,
+                req.text,
+                req.suffix,
+                tokenizer,
+                context,
             )

             inputs = LLMInputs(
@@ -352,6 +360,7 @@ async def GenerateStream(  # noqa: PLR0915, C901
             sampling_params,
             truncate_input_tokens,
             request.request.text,
+            request.request.suffix,
             tokenizer,
             context,
         )
@@ -782,30 +791,47 @@ def _convert_tokens(  # noqa: PLR0913
             )
             token_infos.append(token_info)

-    async def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(  # noqa: PLR0913
         self,
         sampling_params: SamplingParams,
         truncate_input_tokens: int | None,
         prompt: str,
+        suffix: str | None,
         tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[list[int], bool]:
         assert self.config is not None

+        max_model_len = self.config.max_model_len
+
-        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": ADD_SPECIAL_TOKENS}
-        if truncate_input_tokens is not None:
-            tokenizer_kwargs.update(
-                {
-                    "truncation": True,
-                    "max_length": truncate_input_tokens,
-                }
-            )
+        if suffix:
+            if not (get_fim_encoder := self.get_fim_encoder):
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "fim support must be enabled to use suffix",
+                )
+            if truncate_input_tokens is not None:
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "truncate_input_tokens cannot be used with suffix",
+                )
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt, suffix=suffix)
+        else:
+            tokenizer_kwargs: dict[str, Any] = {
+                "add_special_tokens": ADD_SPECIAL_TOKENS
+            }
+            if truncate_input_tokens is not None:
+                tokenizer_kwargs.update(
+                    {
+                        "truncation": True,
+                        "max_length": truncate_input_tokens,
+                    }
+                )

-        input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         token_num = len(input_ids)

-        max_model_len = self.config.max_model_len
-
         try:
             validate_input(sampling_params, token_num, max_model_len)
         except ValueError as tgis_validation_error:
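For orientation, here is a minimal sketch of what an encode_with_suffix implementation could look like for a model with StarCoder-style sentinel tokens. The token strings and the prefix-suffix-middle ordering are assumptions for illustration; the real encoder returned by get_fim_encoder is model-specific and provided by vLLM, not defined in this adapter.

class StarCoderStyleFIMEncoder:
    """Illustrative only: assumes the tokenizer vocabulary contains these sentinels."""

    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer

    def encode_with_suffix(self, *, prefix: str, suffix: str) -> list[int]:
        # PSM ordering: <fim_prefix> prefix <fim_suffix> suffix <fim_middle>
        text = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
        # add_special_tokens=False so BOS/EOS are not added around the FIM template
        return self.tokenizer(text, add_special_tokens=False).input_ids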
2 changes: 2 additions & 0 deletions src/vllm_tgis_adapter/grpc/pb/generation.proto
@@ -51,6 +51,8 @@ message BatchedGenerationResponse {

 message GenerationRequest {
   string text = 2;
+  /// Optional, for fill-in-middle
+  string suffix = 3;
 }

 message GenerationResponse {

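Because generation.proto gains a field, any locally generated Python stubs need regenerating before clients can set suffix. A sketch using grpcio-tools follows; the include and output paths are assumptions and should match wherever the stubs live in your checkout.

from grpc_tools import protoc

# Paths are assumptions; adjust to where generation.proto lives in your tree.
protoc.main(
    [
        "grpc_tools.protoc",
        "-Isrc/vllm_tgis_adapter/grpc/pb",
        "--python_out=src/vllm_tgis_adapter/grpc/pb",
        "--grpc_python_out=src/vllm_tgis_adapter/grpc/pb",
        "generation.proto",
    ]
)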