From e74ee3964632795542d2f79ecdadc4702b5856ea Mon Sep 17 00:00:00 2001 From: Robert Samoilescu Date: Fri, 12 Jul 2024 12:00:35 +0100 Subject: [PATCH 1/6] Included streaming support for grpc prom server interceptor. --- mlserver/grpc/interceptors.py | 245 ++++++++++++++++++++-------------- 1 file changed, 147 insertions(+), 98 deletions(-) diff --git a/mlserver/grpc/interceptors.py b/mlserver/grpc/interceptors.py index c1d6e2df8..ccde3509c 100644 --- a/mlserver/grpc/interceptors.py +++ b/mlserver/grpc/interceptors.py @@ -50,43 +50,10 @@ async def intercept_service( metrics_wrapper = partial(self._metrics_wrapper, method_call) return self._interceptor._wrap_rpc_behavior(handler, metrics_wrapper) - def _compute_status_code(self, servicer_context: ServicerContext) -> StatusCode: - """ - This method is mostly copied from `py-grpc-prometheus`, with a couple - minor changes to avoid using private APIs from ServicerContext which - don't exist anymore in `grpc.aio`. - To see the original implementation, please check: - - https://github.com/lchenn/py-grpc-prometheus/blob/eb9dee1f0a4e57cef220193ee48021dc9a9f3d82/py_grpc_prometheus/prometheus_server_interceptor.py#L127-L134 - """ - # Backwards compatibility for non-aio. - # TODO: It's not clear yet how to check whether the context has been - # cancelled with aio. - if hasattr(servicer_context, "_state"): - if servicer_context._state.client == "cancelled": - return StatusCode.CANCELLED - - if not hasattr(servicer_context, "code"): - return StatusCode.OK - - code = servicer_context.code() - if code is None: - return StatusCode.OK - - # NOTE: With gRPC AIO, the `code` can be a plain integer that needs to - # be converted to an actual `StatusCode` entry - if isinstance(code, int): - if code not in self._status_codes: - return StatusCode.UNKNOWN - - return self._status_codes[code] - - return code - def _metrics_wrapper( self, method_call: Tuple[str, str, str], - old_handler: RpcMethodHandler, + behavior: RpcMethodHandler, request_streaming: bool, response_streaming: bool, ): @@ -99,60 +66,111 @@ def _metrics_wrapper( """ grpc_service_name, grpc_method_name, _ = method_call - async def _new_handler(request_or_iterator, servicer_context: ServicerContext): - response_or_iterator = None + async def new_behavior(request, servicer_context): + response = None try: start = default_timer() grpc_type = grpc_utils.get_method_type( request_streaming, response_streaming ) + try: - if request_streaming: - request_or_iterator = grpc_utils.wrap_iterator_inc_counter( - request_or_iterator, - self._interceptor._metrics[ - "grpc_server_stream_msg_received" - ], - grpc_type, - grpc_service_name, - grpc_method_name, + self._interceptor._metrics["grpc_server_started_counter"].labels( + grpc_type=grpc_type, + grpc_service=grpc_service_name, + grpc_method=grpc_method_name, + ).inc() + + # Invoke the original rpc behavior. + # NOTE: This is the main change required with respect to + # the original implementation in `py-grpc-prometheus`. 
+ response = await behavior(request, servicer_context) + self._interceptor.increase_grpc_server_handled_total_counter( + grpc_type, + grpc_service_name, + grpc_method_name, + self._compute_status_code(servicer_context).name, + ) + return response + + except RpcError as e: + self._interceptor.increase_grpc_server_handled_total_counter( + grpc_type, + grpc_service_name, + grpc_method_name, + self._interceptor._compute_error_code(e).name, + ) + raise e + + finally: + if self._interceptor._legacy: + self._interceptor._metrics[ + "legacy_grpc_server_handled_latency_seconds" + ].labels( + grpc_type=grpc_type, + grpc_service=grpc_service_name, + grpc_method=grpc_method_name, + ).observe( + max(default_timer() - start, 0) ) - else: + elif self._interceptor._enable_handling_time_histogram: self._interceptor._metrics[ - "grpc_server_started_counter" + "grpc_server_handled_histogram" ].labels( grpc_type=grpc_type, grpc_service=grpc_service_name, grpc_method=grpc_method_name, - ).inc() + ).observe( + max(default_timer() - start, 0) + ) + except Exception as e: # pylint: disable=broad-except + # Allow user to skip the exceptions in order to maintain + # the basic functionality in the server + # The logging function in exception can be toggled with log_exceptions + # in order to suppress the noise in logging + if self._interceptor._skip_exceptions: + if self._interceptor._log_exceptions: + logger.error(e) - # Invoke the original rpc behavior. - # NOTE: This is the main change required with respect to - # the original implementation in `py-grpc-prometheus`. - response_or_iterator = await old_handler( - request_or_iterator, servicer_context + if response is None: + return response + + return await behavior(request, servicer_context) + raise e + + async def new_behavior_stream( + request_async_iterator, servicer_context: ServicerContext + ): + response_async_iterator = None + try: + grpc_type = grpc_utils.get_method_type( + request_streaming, response_streaming + ) + try: + request_async_iterator = wrap_async_iterator_inc_counter( + request_async_iterator, + self._interceptor._metrics["grpc_server_stream_msg_received"], + grpc_type, + grpc_service_name, + grpc_method_name, ) - if response_streaming: - sent_metric = self._interceptor._metrics[ - "grpc_server_stream_msg_sent" - ] - response_or_iterator = grpc_utils.wrap_iterator_inc_counter( - response_or_iterator, - sent_metric, - grpc_type, - grpc_service_name, - grpc_method_name, - ) + # wrap the original behavior with the metrics + sent_metric = self._interceptor._metrics[ + "grpc_server_stream_msg_sent" + ] + response_async_iterator = wrap_async_iterator_inc_counter( + behavior(request_async_iterator, servicer_context), + sent_metric, + grpc_type, + grpc_service_name, + grpc_method_name, + ) + + # invoke the original rpc behavior + async for item in response_async_iterator: + yield item - else: - self._interceptor.increase_grpc_server_handled_total_counter( - grpc_type, - grpc_service_name, - grpc_method_name, - self._compute_status_code(servicer_context).name, - ) - return response_or_iterator except RpcError as e: self._interceptor.increase_grpc_server_handled_total_counter( grpc_type, @@ -162,28 +180,6 @@ async def _new_handler(request_or_iterator, servicer_context: ServicerContext): ) raise e - finally: - if not response_streaming: - if self._interceptor._legacy: - self._interceptor._metrics[ - "legacy_grpc_server_handled_latency_seconds" - ].labels( - grpc_type=grpc_type, - grpc_service=grpc_service_name, - grpc_method=grpc_method_name, - 
).observe( - max(default_timer() - start, 0) - ) - elif self._interceptor._enable_handling_time_histogram: - self._interceptor._metrics[ - "grpc_server_handled_histogram" - ].labels( - grpc_type=grpc_type, - grpc_service=grpc_service_name, - grpc_method=grpc_method_name, - ).observe( - max(default_timer() - start, 0) - ) except Exception as e: # pylint: disable=broad-except # Allow user to skip the exceptions in order to maintain # the basic functionality in the server @@ -192,9 +188,62 @@ async def _new_handler(request_or_iterator, servicer_context: ServicerContext): if self._interceptor._skip_exceptions: if self._interceptor._log_exceptions: logger.error(e) - if response_or_iterator is None: - return response_or_iterator - return old_handler(request_or_iterator, servicer_context) + + if response_async_iterator is not None: + async for item in behavior( + request_async_iterator, servicer_context + ): + yield item raise e - return _new_handler + if request_streaming and response_streaming: + return new_behavior_stream + + return new_behavior + + def _compute_status_code(self, servicer_context: ServicerContext) -> StatusCode: + """ + This method is mostly copied from `py-grpc-prometheus`, with a couple + minor changes to avoid using private APIs from ServicerContext which + don't exist anymore in `grpc.aio`. + To see the original implementation, please check: + + https://github.com/lchenn/py-grpc-prometheus/blob/eb9dee1f0a4e57cef220193ee48021dc9a9f3d82/py_grpc_prometheus/prometheus_server_interceptor.py#L127-L134 + """ + # Backwards compatibility for non-aio. + # TODO: It's not clear yet how to check whether the context has been + # cancelled with aio. + if hasattr(servicer_context, "_state"): + if servicer_context._state.client == "cancelled": + return StatusCode.CANCELLED + + if not hasattr(servicer_context, "code"): + return StatusCode.OK + + code = servicer_context.code() + if code is None: + return StatusCode.OK + + # NOTE: With gRPC AIO, the `code` can be a plain integer that needs to + # be converted to an actual `StatusCode` entry + if isinstance(code, int): + if code not in self._status_codes: + return StatusCode.UNKNOWN + + return self._status_codes[code] + + return code + + +async def wrap_async_iterator_inc_counter( + iterator, counter, grpc_type, grpc_service_name, grpc_method_name +): + """Wraps an async iterator and collect metrics.""" + + async for item in iterator: + counter.labels( + grpc_type=grpc_type, + grpc_service=grpc_service_name, + grpc_method=grpc_method_name, + ).inc() + yield item From 353a8c585a2a2715e15884eba8a2241273763625 Mon Sep 17 00:00:00 2001 From: Robert Samoilescu Date: Fri, 12 Jul 2024 15:09:24 +0100 Subject: [PATCH 2/6] Included prometheus interceptor tests. 
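The tests added below assert directly on the interceptor's Prometheus counters. For reference, here is a minimal, self-contained sketch of the `wrap_async_iterator_inc_counter` helper introduced in the previous patch — not part of this commit; the counter name, the fake stream and the `demo.Service`/`DemoMethod` labels are made up for illustration, and `prometheus_client` is assumed to be installed:

import asyncio
from prometheus_client import Counter
from mlserver.grpc.interceptors import wrap_async_iterator_inc_counter

async def main():
    # Throwaway counter using the same label names as py-grpc-prometheus.
    counter = Counter(
        "demo_grpc_server_stream_msg_received",
        "Messages seen by the wrapper (demo only).",
        ["grpc_type", "grpc_service", "grpc_method"],
    )

    async def fake_stream():
        # Stand-in for the request/response async iterator of a streaming RPC.
        for item in ("a", "b", "c"):
            yield item

    # Wrap the iterator; every yielded item increments the labelled counter.
    wrapped = wrap_async_iterator_inc_counter(
        fake_stream(), counter, "BIDI_STREAMING", "demo.Service", "DemoMethod"
    )
    async for _ in wrapped:
        pass

    # Same read pattern as in the tests below.
    value = counter.labels("BIDI_STREAMING", "demo.Service", "DemoMethod")._value.get()
    assert int(value) == 3

asyncio.run(main())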
--- mlserver/grpc/server.py | 10 +-- tests/grpc/test_interceptor.py | 129 ++++++++++++++++++++++++++++ tests/testdata/settings-stream.json | 1 - 3 files changed, 134 insertions(+), 6 deletions(-) create mode 100644 tests/grpc/test_interceptor.py diff --git a/mlserver/grpc/server.py b/mlserver/grpc/server.py index cc5a84529..2d8fef3e0 100644 --- a/mlserver/grpc/server.py +++ b/mlserver/grpc/server.py @@ -38,14 +38,14 @@ def _create_server(self): self._model_repository_handlers ) - interceptors = [] + self._interceptors = [] if self._settings.debug: # If debug, enable access logs - interceptors = [LoggingInterceptor()] + self._interceptors = [LoggingInterceptor()] if self._settings.metrics_endpoint: - interceptors.append( + self._interceptors.append( PromServerInterceptor(enable_handling_time_histogram=True) ) @@ -62,7 +62,7 @@ def _create_server(self): ) ) - interceptors.append( + self._interceptors.append( aio_server_interceptor( tracer_provider=tracer_provider, filter_=excluded_urls ) @@ -70,7 +70,7 @@ def _create_server(self): self._server = aio.server( ThreadPoolExecutor(max_workers=DefaultGrpcWorkers), - interceptors=tuple(interceptors), + interceptors=tuple(self._interceptors), options=self._get_options(), ) diff --git a/tests/grpc/test_interceptor.py b/tests/grpc/test_interceptor.py new file mode 100644 index 000000000..2ef39d5e0 --- /dev/null +++ b/tests/grpc/test_interceptor.py @@ -0,0 +1,129 @@ +import pytest +from pytest_lazyfixture import lazy_fixture + +from typing import AsyncGenerator + +from grpc import StatusCode +from mlserver.grpc.interceptors import PromServerInterceptor +from mlserver.codecs import StringCodec +from mlserver.grpc import converters +from mlserver.grpc.server import GRPCServer +from mlserver.grpc.dataplane_pb2_grpc import GRPCInferenceServiceStub +from mlserver.grpc import dataplane_pb2 as pb + + +@pytest.mark.parametrize("sum_model", [lazy_fixture("text_model")]) +@pytest.mark.parametrize("sum_model_settings", [lazy_fixture("text_model_settings")]) +async def test_prometheus_unary_unary( + grpc_server: GRPCServer, + inference_service_stub: AsyncGenerator[GRPCInferenceServiceStub, None], + model_generate_request: pb.ModelInferRequest, +): + # send 10 requests + num_requests = 10 + for _ in range(num_requests): + _ = await inference_service_stub.ModelInfer(model_generate_request) + + grpc_type = "UNARY" + grpc_service_name = "inference.GRPCInferenceService" + grpc_method_name = "ModelInfer" + prom_interceptor = [ + interceptor + for interceptor in grpc_server._interceptors + if isinstance(interceptor, PromServerInterceptor) + ][0] + + # get the number of requests intercepted + counted_requests = ( + prom_interceptor._interceptor._metrics["grpc_server_started_counter"] + .labels( + grpc_type, + grpc_service_name, + grpc_method_name, + ) + ._value.get() + ) + + # get the number of ok responses intercepted + counted_responses = ( + prom_interceptor._interceptor._grpc_server_handled_total_counter.labels( + grpc_type, + grpc_service_name, + grpc_method_name, + StatusCode.OK.name, + )._value.get() + ) + + assert int(counted_requests) == num_requests + assert int(counted_requests) == int(counted_responses) + + +@pytest.mark.parametrize("settings", [lazy_fixture("settings_stream")]) +@pytest.mark.parametrize("sum_model", [lazy_fixture("text_stream_model")]) +@pytest.mark.parametrize("model_name", ["text-stream-model"]) +@pytest.mark.parametrize( + "sum_model_settings", [lazy_fixture("text_stream_model_settings")] +) +async def test_prometheus_stream_stream( + 
grpc_server: GRPCServer, + inference_service_stub: AsyncGenerator[GRPCInferenceServiceStub, None], + model_generate_request: pb.ModelInferRequest, + model_name: str, +): + model_generate_request.model_name = model_name + + async def get_stream_request(request): + yield request + + # send 10 requests + num_requests = 1 + for _ in range(num_requests): + _ = [ + _ + async for _ in inference_service_stub.ModelStreamInfer( + get_stream_request(model_generate_request) + ) + ] + + grpc_type = "BIDI_STREAMING" + grpc_service_name = "inference.GRPCInferenceService" + grpc_method_name = "ModelStreamInfer" + prom_interceptor = [ + interceptor + for interceptor in grpc_server._interceptors + if isinstance(interceptor, PromServerInterceptor) + ][0] + + # get the number of requests intercepted + counted_requests = ( + prom_interceptor._interceptor._metrics["grpc_server_stream_msg_received"] + .labels( + grpc_type, + grpc_service_name, + grpc_method_name, + ) + ._value.get() + ) + + # get the number of ok responses intercepted + counted_responses = ( + prom_interceptor._interceptor._metrics["grpc_server_stream_msg_sent"] + .labels( + grpc_type, + grpc_service_name, + grpc_method_name, + ) + ._value.get() + ) + + inference_request_g = converters.ModelInferRequestConverter.to_types( + model_generate_request + ) + + # we count the number of words because + # each word is gonna be streamed back + request_text = StringCodec.decode_input(inference_request_g.inputs[0])[0] + num_words = len(request_text.split()) + + assert int(counted_requests) == num_requests + assert int(counted_requests) * num_words == int(counted_responses) diff --git a/tests/testdata/settings-stream.json b/tests/testdata/settings-stream.json index 6727d5b59..809cb452b 100644 --- a/tests/testdata/settings-stream.json +++ b/tests/testdata/settings-stream.json @@ -3,7 +3,6 @@ "host": "127.0.0.1", "parallel_workers": 0, "gzip_enabled": false, - "metrics_endpoint": null, "cors_settings": { "allow_origins": ["*"] } From f4066f2cbe41c8ec447c395587860cbc8a6b2703 Mon Sep 17 00:00:00 2001 From: Robert Samoilescu Date: Fri, 12 Jul 2024 15:35:41 +0100 Subject: [PATCH 3/6] Updated docs. --- docs/examples/streaming/README.ipynb | 57 ++++++++++++++++++--------- docs/examples/streaming/README.md | 8 +--- docs/examples/streaming/settings.json | 3 +- docs/examples/streaming/text_model.py | 14 +------ docs/user-guide/streaming.md | 1 - 5 files changed, 42 insertions(+), 41 deletions(-) diff --git a/docs/examples/streaming/README.ipynb b/docs/examples/streaming/README.ipynb index 025246237..1273b26c4 100644 --- a/docs/examples/streaming/README.ipynb +++ b/docs/examples/streaming/README.ipynb @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -121,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -138,8 +138,7 @@ "{\n", " \"debug\": false,\n", " \"parallel_workers\": 0,\n", - " \"gzip_enabled\": false,\n", - " \"metrics_endpoint\": null\n", + " \"gzip_enabled\": false\n", "}\n" ] }, @@ -150,8 +149,7 @@ "Note the currently there are three main limitations of the streaming support in MLServer:\n", "\n", "- distributed workers are not supported (i.e., the `parallel_workers` setting should be set to `0`)\n", - "- `gzip` middleware is not supported for REST (i.e., `gzip_enabled` setting should be set to `false`)\n", - "- metrics endpoint is not available (i.e. 
`metrics_endpoint` is also disabled for streaming for gRPC)" + "- `gzip` middleware is not supported for REST (i.e., `gzip_enabled` setting should be set to `false`)" ] }, { @@ -163,7 +161,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -227,14 +225,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Writing generate-request.json\n" + "Overwriting generate-request.json\n" ] } ], @@ -272,9 +270,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['What']\n", + "[' is']\n", + "[' the']\n", + "[' capital']\n", + "[' of']\n", + "[' France?']\n" + ] + } + ], "source": [ "import httpx\n", "from httpx_sse import connect_sse\n", @@ -301,9 +312,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['What']\n", + "[' is']\n", + "[' the']\n", + "[' capital']\n", + "[' of']\n", + "[' France?']\n" + ] + } + ], "source": [ "import grpc\n", "import mlserver.types as types\n", @@ -315,7 +339,7 @@ "inference_request = types.InferenceRequest.parse_file(\"./generate-request.json\")\n", "\n", "# need to convert from string to bytes for grpc\n", - "inference_request.inputs[0] = StringCodec.encode_input(\"prompt\", inference_request.inputs[0].data.__root__)\n", + "inference_request.inputs[0] = StringCodec.encode_input(\"prompt\", inference_request.inputs[0].data.root)\n", "inference_request_g = converters.ModelInferRequestConverter.from_types(\n", " inference_request, model_name=\"text-model\", model_version=None\n", ")\n", @@ -338,11 +362,6 @@ "source": [ "Note that for gRPC, the request is transformed into an async generator which is then passed to the `ModelStreamInfer` method. The response is also an async generator which can be iterated over to get the response." ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { diff --git a/docs/examples/streaming/README.md b/docs/examples/streaming/README.md index 7acdf2090..d91aa7492 100644 --- a/docs/examples/streaming/README.md +++ b/docs/examples/streaming/README.md @@ -78,8 +78,7 @@ The next step will be to create 2 configuration files: { "debug": false, "parallel_workers": 0, - "gzip_enabled": false, - "metrics_endpoint": null + "gzip_enabled": false } ``` @@ -88,7 +87,6 @@ Note the currently there are three main limitations of the streaming support in - distributed workers are not supported (i.e., the `parallel_workers` setting should be set to `0`) - `gzip` middleware is not supported for REST (i.e., `gzip_enabled` setting should be set to `false`) -- metrics endpoint is not available (i.e. 
`metrics_endpoint` is also disabled for streaming for gRPC) #### model-settings.json @@ -195,7 +193,7 @@ import mlserver.grpc.dataplane_pb2_grpc as dataplane inference_request = types.InferenceRequest.parse_file("./generate-request.json") # need to convert from string to bytes for grpc -inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__) +inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.root) inference_request_g = converters.ModelInferRequestConverter.from_types( inference_request, model_name="text-model", model_version=None ) @@ -213,5 +211,3 @@ async with grpc.aio.insecure_channel("localhost:8081") as grpc_channel: ``` Note that for gRPC, the request is transformed into an async generator which is then passed to the `ModelStreamInfer` method. The response is also an async generator which can be iterated over to get the response. - - diff --git a/docs/examples/streaming/settings.json b/docs/examples/streaming/settings.json index ec853b3ba..3a95c2882 100644 --- a/docs/examples/streaming/settings.json +++ b/docs/examples/streaming/settings.json @@ -2,6 +2,5 @@ { "debug": false, "parallel_workers": 0, - "gzip_enabled": false, - "metrics_endpoint": null + "gzip_enabled": false } diff --git a/docs/examples/streaming/text_model.py b/docs/examples/streaming/text_model.py index 4475b3c92..35b167bb5 100644 --- a/docs/examples/streaming/text_model.py +++ b/docs/examples/streaming/text_model.py @@ -1,3 +1,4 @@ + import asyncio from typing import AsyncIterator from mlserver import MLModel @@ -7,19 +8,6 @@ class TextModel(MLModel): - async def predict(self, payload: InferenceRequest) -> InferenceResponse: - text = StringCodec.decode_input(payload.inputs[0])[0] - return InferenceResponse( - model_name=self._settings.name, - outputs=[ - StringCodec.encode_output( - name="output", - payload=[text], - use_bytes=True, - ), - ], - ) - async def predict_stream( self, payloads: AsyncIterator[InferenceRequest] ) -> AsyncIterator[InferenceResponse]: diff --git a/docs/user-guide/streaming.md b/docs/user-guide/streaming.md index 41dec0b03..a576e6a3e 100644 --- a/docs/user-guide/streaming.md +++ b/docs/user-guide/streaming.md @@ -32,4 +32,3 @@ There are three main limitations of the streaming support in MLServer: - the `parallel_workers` setting should be set to `0` to disable distributed workers (to be addressed in future releases) - for REST, the `gzip_enabled` setting should be set to `false` to disable GZIP compression, as streaming is not compatible with GZIP compression (see issue [here]( https://github.com/encode/starlette/issues/20#issuecomment-704106436)) -- `metrics_endpoint` is also disabled for streaming for gRPC (to be addressed in future releases) \ No newline at end of file From f95f4ca3ef680dd1b8129c8187905e78a6bee2e2 Mon Sep 17 00:00:00 2001 From: Robert Samoilescu Date: Fri, 12 Jul 2024 15:44:13 +0100 Subject: [PATCH 4/6] Fix linting. 
--- docs/examples/streaming/README.ipynb | 2 +- docs/examples/streaming/text_model.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/examples/streaming/README.ipynb b/docs/examples/streaming/README.ipynb index 1273b26c4..13755c6d9 100644 --- a/docs/examples/streaming/README.ipynb +++ b/docs/examples/streaming/README.ipynb @@ -380,7 +380,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/examples/streaming/text_model.py b/docs/examples/streaming/text_model.py index 35b167bb5..d851f3bb9 100644 --- a/docs/examples/streaming/text_model.py +++ b/docs/examples/streaming/text_model.py @@ -1,4 +1,3 @@ - import asyncio from typing import AsyncIterator from mlserver import MLModel From 5e756e078b1fe58384927589dc4a72f6896102e9 Mon Sep 17 00:00:00 2001 From: Robert Samoilescu Date: Mon, 15 Jul 2024 11:49:28 +0100 Subject: [PATCH 5/6] Addressed PR review comments. --- mlserver/grpc/interceptors.py | 27 +++++++++++++++++---------- tests/grpc/test_interceptor.py | 2 +- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/mlserver/grpc/interceptors.py b/mlserver/grpc/interceptors.py index ccde3509c..99de6d500 100644 --- a/mlserver/grpc/interceptors.py +++ b/mlserver/grpc/interceptors.py @@ -1,9 +1,12 @@ -from typing import Awaitable, Callable, Tuple +from typing import Awaitable, AsyncIterator, Callable, Tuple, Optional from functools import partial from timeit import default_timer +from mlserver.grpc import dataplane_pb2 as pb from grpc.aio import ServerInterceptor, ServicerContext from grpc import HandlerCallDetails, RpcMethodHandler, RpcError, StatusCode + +from prometheus_client import Counter from py_grpc_prometheus.prometheus_server_interceptor import ( grpc_utils, PromServerInterceptor as _PromServerInterceptor, @@ -66,7 +69,9 @@ def _metrics_wrapper( """ grpc_service_name, grpc_method_name, _ = method_call - async def new_behavior(request, servicer_context): + async def new_behavior( + request: pb.ModelMetadataRequest, servicer_context: ServicerContext + ) -> Optional[pb.ModelMetadataRequest]: response = None try: start = default_timer() @@ -139,8 +144,9 @@ async def new_behavior(request, servicer_context): raise e async def new_behavior_stream( - request_async_iterator, servicer_context: ServicerContext - ): + request_async_iterator: AsyncIterator[pb.ModelInferRequest], + servicer_context: ServicerContext, + ) -> AsyncIterator[pb.ModelInferRequest]: response_async_iterator = None try: grpc_type = grpc_utils.get_method_type( @@ -156,12 +162,9 @@ async def new_behavior_stream( ) # wrap the original behavior with the metrics - sent_metric = self._interceptor._metrics[ - "grpc_server_stream_msg_sent" - ] response_async_iterator = wrap_async_iterator_inc_counter( behavior(request_async_iterator, servicer_context), - sent_metric, + self._interceptor._metrics["grpc_server_stream_msg_sent"], grpc_type, grpc_service_name, grpc_method_name, @@ -236,8 +239,12 @@ def _compute_status_code(self, servicer_context: ServicerContext) -> StatusCode: async def wrap_async_iterator_inc_counter( - iterator, counter, grpc_type, grpc_service_name, grpc_method_name -): + iterator: AsyncIterator[pb.ModelInferRequest], + counter: Counter, + grpc_type: str, + grpc_service_name: str, + grpc_method_name: str, +) -> AsyncIterator[pb.ModelInferRequest]: """Wraps an async iterator and collect metrics.""" async for item in iterator: diff --git a/tests/grpc/test_interceptor.py 
b/tests/grpc/test_interceptor.py index 2ef39d5e0..ca3a7f007 100644 --- a/tests/grpc/test_interceptor.py +++ b/tests/grpc/test_interceptor.py @@ -76,7 +76,7 @@ async def get_stream_request(request): yield request # send 10 requests - num_requests = 1 + num_requests = 10 for _ in range(num_requests): _ = [ _ From 55a271e1f3cb22cafea45572cedcef883909ba5d Mon Sep 17 00:00:00 2001 From: Robert Samoilescu Date: Mon, 15 Jul 2024 12:25:31 +0100 Subject: [PATCH 6/6] Run grpc interceptor separately. --- tox.ini | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index e4a6e37be..034cfaa05 100644 --- a/tox.ini +++ b/tox.ini @@ -22,12 +22,17 @@ commands = python -m pytest {posargs} -n auto \ {toxinidir}/tests \ --ignore={toxinidir}/tests/kafka \ - --ignore={toxinidir}/tests/parallel + --ignore={toxinidir}/tests/parallel \ + --ignore={toxinidir}/tests/grpc/test_interceptor.py # kafka and parallel tests are failing for macos when running in parallel # with the entire test suite. So, we run them separately. + # Also, we run the grpc interceptor test separately because + # other tests will interfere with the metrics counter when + # running in parallel. python -m pytest {posargs} \ {toxinidir}/tests/kafka \ - {toxinidir}/tests/parallel + {toxinidir}/tests/parallel \ + {toxinidir}/tests/grpc/test_interceptor.py set_env = GITHUB_SERVER_URL = {env:GITHUB_SERVER_URL:https\://github.com} GITHUB_REPOSITORY = {env:GITHUB_REPOSITORY:SeldonIO/MLServer}
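As an end-to-end reference, below is a minimal sketch — not part of the patches above — of how the extended PromServerInterceptor gets wired into a grpc.aio server, mirroring what GRPCServer._create_server() does in PATCH 2. The port (8081) and worker count (4) are arbitrary choices for the example, and the service handlers are omitted:

import asyncio
from concurrent.futures import ThreadPoolExecutor

from grpc import aio
from mlserver.grpc.interceptors import PromServerInterceptor

async def serve() -> None:
    # Register the Prometheus interceptor when creating the aio server,
    # as done in mlserver/grpc/server.py when metrics_endpoint is enabled.
    server = aio.server(
        ThreadPoolExecutor(max_workers=4),
        interceptors=(PromServerInterceptor(enable_handling_time_histogram=True),),
    )
    # Service handlers are omitted here; MLServer registers the inference
    # and model-repository servicers before starting the server.
    server.add_insecure_port("0.0.0.0:8081")
    await server.start()
    await server.wait_for_termination()

asyncio.run(serve())

As the tox.ini comment above notes, the interceptor tests rely on process-global metric counters, so running them in isolation (for example with `python -m pytest tests/grpc/test_interceptor.py`) avoids interference from other tests executing in parallel.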