From 79cd0d4f4dd1ea012c6e77d41141e66338accd0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Mon, 16 Sep 2024 16:21:09 +0200 Subject: [PATCH] grpc_server: use x-correlation-id as request-id when possible --- src/vllm_tgis_adapter/grpc/grpc_server.py | 13 +++++++++-- tests/test_grpc_server.py | 27 +++++++++++++++++++++++ tests/utils.py | 2 ++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 408b6b99..1ee9bccc 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -518,8 +518,17 @@ def _convert_output( # noqa: PLR0913 return response @staticmethod - def request_id(context: ServicerContext) -> str: # noqa: ARG004 - return uuid.uuid4().hex + def request_id(context: ServicerContext) -> str: + metadata = context.invocation_metadata() + if not metadata: + return uuid.uuid4().hex + + correlation_id = dict(metadata).get("x-correlation-id") + + if not correlation_id: + return uuid.uuid4().hex + + return correlation_id async def _validate_and_convert_params( self, diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py index 653070a0..e64431eb 100644 --- a/tests/test_grpc_server.py +++ b/tests/test_grpc_server.py @@ -55,3 +55,30 @@ def test_lora_request(grpc_client, lora_adapter_name): response = grpc_client.make_request("hello", adapter_id=lora_adapter_name) assert response.text + + +def test_request_id(grpc_client, mocker): + from vllm_tgis_adapter.grpc.grpc_server import TextGenerationService, uuid + + spy = mocker.spy(TextGenerationService, "request_id") + response = grpc_client.make_request( + "The answer to life the universe and everything is ", + metadata=[("x-correlation-id", "dummy-correlation-id")], + ) + assert response.text + + spy.assert_called_once() + assert spy.spy_return == "dummy-correlation-id" + + spy.reset_mock() + + request_id = uuid.uuid4() + mocker.patch.object(uuid, "uuid4", return_value=request_id) + + response = grpc_client.make_request( + "The answer to life the universe and everything is ", + ) + assert response.text + + spy.assert_called_once() + assert spy.spy_return == request_id.hex diff --git a/tests/utils.py b/tests/utils.py index 969b7281..e848cb38 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -126,6 +126,7 @@ def make_request( model_id: str | None = None, max_new_tokens: int = 10, adapter_id: str | None = None, + metadata: list[tuple[str, str]] | None = None, ) -> GenerationResponse | Sequence[GenerationResponse]: # assert model_id # FIXME: is model_id required? @@ -143,6 +144,7 @@ def make_request( response = self.generation_service_stub.Generate( request=request, + metadata=metadata, ) if single_request: