diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 6bf6240e..64c04ea1 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -14,6 +14,7 @@ from grpc_health.v1 import health, health_pb2, health_pb2_grpc from grpc_reflection.v1alpha import reflection from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing import MQEngineDeadError from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.inputs import LLMInputs from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -149,6 +150,9 @@ async def _handle_exception( service_metrics.count_request_failure(FailureReasonLabel.GENERATE) else: service_metrics.count_request_failure(FailureReasonLabel.UNKNOWN) + if isinstance(e, MQEngineDeadError): + logger.error(e) + return logger.exception("%s failed", func.__name__) raise e diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py index a2b2d5ed..aea036b1 100644 --- a/tests/test_grpc_server.py +++ b/tests/test_grpc_server.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from .utils import GrpcClient @@ -90,3 +92,34 @@ def test_request_id(grpc_client, mocker): spy.assert_called_once() assert spy.spy_return == request_id.hex + + +def test_error_handling(mocker): + from vllm.engine.multiprocessing import MQEngineDeadError + + from vllm_tgis_adapter.grpc.grpc_server import _handle_exception, logger + + def dummy_func(): + pass + + class DummyEngine: + errored = False + is_running = True + + class DummyArg: + engine = DummyEngine() + + # General error handling + key_error = KeyError() + dummy_arg_0 = DummyArg() + with pytest.raises(KeyError): + asyncio.run(_handle_exception(key_error, dummy_func, dummy_arg_0)) + + engine_error = MQEngineDeadError("foo:bar") + + # Engine error handling + spy = mocker.spy(logger, "error") + + # Does not raise an exception + 
asyncio.run(_handle_exception(engine_error, dummy_func, dummy_arg_0)) + spy.assert_called_once_with(engine_error)