✨ add tokenization with truncation, offset support
Signed-off-by: Prashant Gupta <[email protected]>
fialhocoelho authored and prashantgupta24 committed Jun 26, 2024
1 parent 864d262 commit b32eff3
Showing 1 changed file with 52 additions and 18 deletions.
70 changes: 52 additions & 18 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -733,32 +733,66 @@ async def _validate_prompt_and_tokenize(
 
     @log_rpc_handler_errors
     async def Tokenize(
-        self, request: BatchedTokenizeRequest, context: ServicerContext
+        self,
+        request: BatchedTokenizeRequest,
+        context: ServicerContext,  # noqa: ARG002
     ) -> BatchedTokenizeResponse:
-        service_metrics.count_tokenization_request(request)
-        # TODO implement these
-        if request.return_offsets:
-            await context.abort(
-                StatusCode.INVALID_ARGUMENT, "return_offsets not yet supported"
-            )
-        if request.truncate_input_tokens:
-            await context.abort(
-                StatusCode.INVALID_ARGUMENT, "truncate_input_tokens not yet supported"
-            )
+        """Handle tokenization requests by tokenizing input texts \
+        and returning tokenized results.
+
+        If request.truncate_input_tokens is
+        provided, the tokenization will contain the truncated results.
+
+        Args:
+        ----
+            request (BatchedTokenizeRequest): The tokenization request
+            containing texts to be tokenized.
+            context (ServicerContext): The context for the RPC call.
+
+        Returns:
+        -------
+            BatchedTokenizeResponse: The response containing the
+            tokenized results.
+
+        """
+        # Log the incoming tokenization request for metrics
+        service_metrics.observe_tokenization_request(request)
 
         responses: list[TokenizeResponse] = []
 
-        # TODO maybe parallelize, also move convert_ids_to_tokens
-        # into the other threads
+        # TODO: maybe parallelize, also move convert_ids_to_tokens into the
+        # other threads
         for req in request.requests:
-            token_ids = await self.tokenizer_group.encode_async(req.text)
+            batch_encoding = self.tokenizer.encode_plus(
+                text=req.text, return_offsets_mapping=request.return_offsets
+            )
+
+            # Tokenize the input text
+            token_ids = batch_encoding.input_ids
+            token_count = len(token_ids)
+
+            if 0 < request.truncate_input_tokens < token_count:
+                token_count = request.truncate_input_tokens
+
+            # Initialize Tokens from ids
+            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+            offsets = None
+
+            if request.return_offsets:
+                offsets = [
+                    {"start": start, "end": end}
+                    for start, end in batch_encoding.offset_mapping
+                    if start is not None and end is not None
+                ]
+                # Truncate offset list if request.truncate_input_tokens
+                offsets = offsets[-token_count:]
+
+            tokens = tokens[-token_count:] if request.return_tokens else None
+
             responses.append(
                 TokenizeResponse(
-                    token_count=len(token_ids),
-                    tokens=self.tokenizer.convert_ids_to_tokens(token_ids)
-                    if request.return_tokens
-                    else None,
+                    token_count=token_count, tokens=tokens, offsets=offsets
                 )
             )

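For reference, here is a minimal standalone sketch of the truncation and offset logic introduced above, run against a Hugging Face fast tokenizer directly (the model name and surrounding scaffolding are illustrative and not part of this commit; offset mappings are only available on fast tokenizers):

# Standalone sketch of the handler's truncation/offset logic (illustrative).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model

text = "Hello there, tokenizer!"
truncate_input_tokens = 3  # keep only the last 3 tokens, mirroring the handler

enc = tokenizer.encode_plus(text=text, return_offsets_mapping=True)
token_ids = enc.input_ids
token_count = len(token_ids)

if 0 < truncate_input_tokens < token_count:
    token_count = truncate_input_tokens

tokens = tokenizer.convert_ids_to_tokens(token_ids)

# Offsets are (start, end) character positions into the original text.
offsets = [
    {"start": start, "end": end}
    for start, end in enc.offset_mapping
    if start is not None and end is not None
]

# Truncation keeps the *last* token_count entries (left truncation).
print(token_count, tokens[-token_count:], offsets[-token_count:])

Note the slicing with [-token_count:]: when truncate_input_tokens is set, the reported token_count, tokens, and offsets all describe only the rightmost kept tokens.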

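A hypothetical client-side call exercising the new fields might look like the sketch below. The stub module names (generation_pb2 / generation_pb2_grpc), the GenerationServiceStub service, the model_id value, and the port are assumptions based on the standard TGIS generation.proto, not definitions made in this commit:

# Hypothetical gRPC client call using the new truncation/offset fields.
import grpc

from generation_pb2 import BatchedTokenizeRequest, TokenizeRequest
from generation_pb2_grpc import GenerationServiceStub

channel = grpc.insecure_channel("localhost:8033")  # port is illustrative
stub = GenerationServiceStub(channel)

request = BatchedTokenizeRequest(
    model_id="my-model",  # illustrative
    requests=[TokenizeRequest(text="The quick brown fox jumps over the lazy dog")],
    return_tokens=True,
    return_offsets=True,       # per-token character offsets, now supported
    truncate_input_tokens=4,   # keep only the last 4 tokens, now supported
)

response = stub.Tokenize(request)
for resp in response.responses:
    print(resp.token_count, list(resp.tokens), list(resp.offsets))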