fix: add_special_tokens in tokenize (#144)
#66 made add_special_tokens true by default, but that behaviour wasn't replicated in /tokenize, resulting in a different token count when ADD_SPECIAL_TOKENS is false. This PR fixes that by passing the flag through in /tokenize, and adds a test for the tokenize method.

I can follow this up with another test that compares the token count between the two methods if required (a sketch of such a test appears after the test-file diff below), but otherwise this closes #141.
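
To illustrate the mismatch, a minimal sketch with a Hugging Face tokenizer (the model name is arbitrary; the server code calls the same encode_plus API):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    text = "how far is Paris from New York?"

    # BERT-style tokenizers wrap the input in [CLS] ... [SEP], so the two
    # counts differ by the number of special tokens that were added.
    n_with = len(tokenizer.encode_plus(text, add_special_tokens=True)["input_ids"])
    n_without = len(tokenizer.encode_plus(text, add_special_tokens=False)["input_ids"])
    assert n_with == n_without + 2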
rafvasq authored Oct 1, 2024
1 parent 896db8b commit 7a3301f
Showing 3 changed files with 38 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -855,7 +855,9 @@ async def Tokenize(
         # other threads
         for req in request.requests:
             batch_encoding = tokenizer.encode_plus(
-                text=req.text, return_offsets_mapping=request.return_offsets
+                text=req.text,
+                return_offsets_mapping=request.return_offsets,
+                add_special_tokens=ADD_SPECIAL_TOKENS,
             )
 
             # Tokenize the input text
8 changes: 8 additions & 0 deletions tests/test_grpc_server.py
@@ -25,6 +25,14 @@ def test_generation_request(grpc_client):
     assert response.stop_reason is not None
 
 
+def test_tokenize_request(grpc_client):
+    response_tokenize = grpc_client.make_request_tokenize(
+        text="Please answer the following question.\nhow far is Paris from New York?",
+    )
+
+    assert response_tokenize.token_count
+
+
 def test_generation_request_stream(grpc_client):
     streaming_response = grpc_client.make_request_stream(
         "The answer to life the universe and everything is ",
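
The follow-up test mentioned in the description could look roughly like the sketch below. It is not part of this commit and assumes the generation response exposes an input_token_count field and that the test client provides a make_request helper:

    def test_tokenize_token_count_matches_generation(grpc_client):
        text = "how far is Paris from New York?"

        response_tokenize = grpc_client.make_request_tokenize(text=text)
        response_generate = grpc_client.make_request(text)

        # With ADD_SPECIAL_TOKENS honoured in both code paths,
        # the two counts should agree.
        assert response_tokenize.token_count == response_generate.input_token_count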
27 changes: 27 additions & 0 deletions tests/utils.py
@@ -11,11 +11,13 @@
 
 from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
     BatchedGenerationRequest,
+    BatchedTokenizeRequest,
     GenerationRequest,
     ModelInfoRequest,
     Parameters,
     SingleGenerationRequest,
     StoppingCriteria,
+    TokenizeRequest,
 )
from vllm_tgis_adapter.grpc.pb.generation_pb2_grpc import GenerationServiceStub

@@ -25,6 +27,7 @@
     from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
         GenerationResponse,
         ModelInfoResponse,
+        TokenizeResponse,
     )

_T = TypeVar("_T")
@@ -173,6 +176,30 @@ def make_request_stream(
         except grpc._channel._MultiThreadedRendezvous as exc:  # noqa: SLF001
             raise RuntimeError(exc.details()) from exc
 
+    def make_request_tokenize(
+        self,
+        text: str | list[str],
+        model_id: str | None = None,
+        adapter_id: str | None = None,
+    ) -> TokenizeResponse | Sequence[TokenizeResponse]:
+        if single_request := isinstance(text, str):
+            text = [text]
+
+        request = BatchedTokenizeRequest(
+            model_id=model_id,
+            requests=[TokenizeRequest(text=piece) for piece in text],
+            adapter_id=adapter_id,
+        )
+
+        response = self.generation_service_stub.Tokenize(
+            request=request,
+        )
+
+        if single_request:
+            return response.responses[0]
+
+        return response.responses
+
     def __enter__(self):  # noqa: D105
         return self

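For reference, the new helper returns a single TokenizeResponse for a plain string and a sequence of responses for a list, so a caller might use it like this (a sketch using the grpc_client fixture from the tests):

    single = grpc_client.make_request_tokenize(text="hello world")
    assert single.token_count

    batch = grpc_client.make_request_tokenize(text=["hello", "world"])
    assert len(batch) == 2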
