Skip to content

Commit

Permalink
Addressed PR review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertSamoilescu committed Jul 15, 2024
1 parent f95f4ca commit 5e756e0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
27 changes: 17 additions & 10 deletions mlserver/grpc/interceptors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Awaitable, Callable, Tuple
from typing import Awaitable, AsyncIterator, Callable, Tuple, Optional
from functools import partial
from timeit import default_timer

from mlserver.grpc import dataplane_pb2 as pb
from grpc.aio import ServerInterceptor, ServicerContext
from grpc import HandlerCallDetails, RpcMethodHandler, RpcError, StatusCode

from prometheus_client import Counter
from py_grpc_prometheus.prometheus_server_interceptor import (
grpc_utils,
PromServerInterceptor as _PromServerInterceptor,
Expand Down Expand Up @@ -66,7 +69,9 @@ def _metrics_wrapper(
"""
grpc_service_name, grpc_method_name, _ = method_call

async def new_behavior(request, servicer_context):
async def new_behavior(
request: pb.ModelMetadataRequest, servicer_context: ServicerContext
) -> Optional[pb.ModelMetadataRequest]:
response = None
try:
start = default_timer()
Expand Down Expand Up @@ -139,8 +144,9 @@ async def new_behavior(request, servicer_context):
raise e

async def new_behavior_stream(
request_async_iterator, servicer_context: ServicerContext
):
request_async_iterator: AsyncIterator[pb.ModelInferRequest],
servicer_context: ServicerContext,
) -> AsyncIterator[pb.ModelInferRequest]:
response_async_iterator = None
try:
grpc_type = grpc_utils.get_method_type(
Expand All @@ -156,12 +162,9 @@ async def new_behavior_stream(
)

# wrap the original behavior with the metrics
sent_metric = self._interceptor._metrics[
"grpc_server_stream_msg_sent"
]
response_async_iterator = wrap_async_iterator_inc_counter(
behavior(request_async_iterator, servicer_context),
sent_metric,
self._interceptor._metrics["grpc_server_stream_msg_sent"],
grpc_type,
grpc_service_name,
grpc_method_name,
Expand Down Expand Up @@ -236,8 +239,12 @@ def _compute_status_code(self, servicer_context: ServicerContext) -> StatusCode:


async def wrap_async_iterator_inc_counter(
iterator, counter, grpc_type, grpc_service_name, grpc_method_name
):
iterator: AsyncIterator[pb.ModelInferRequest],
counter: Counter,
grpc_type: str,
grpc_service_name: str,
grpc_method_name: str,
) -> AsyncIterator[pb.ModelInferRequest]:
"""Wraps an async iterator and collect metrics."""

async for item in iterator:
Expand Down
2 changes: 1 addition & 1 deletion tests/grpc/test_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def get_stream_request(request):
yield request

# send 10 requests
num_requests = 1
num_requests = 10
for _ in range(num_requests):
_ = [
_
Expand Down

0 comments on commit 5e756e0

Please sign in to comment.