diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 34c161e9395ae..2a67609654ea7 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -18,6 +18,12 @@ IPC_DATA_EXT = "_data_socket" +# Generic exception when the engine +# fails to process a batch +class MQEngineBatchError(Exception): + pass + + class MQEngineDeadError(RuntimeError): pass diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index b0d061dbab4a1..3924dde80e66b 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -21,9 +21,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse, + VLLM_RPC_SUCCESS_STR, + MQEngineBatchError, MQEngineDeadError, + RPCAbortRequest, RPCError, + RPCProcessRequest, RPCStartupRequest, + RPCStartupResponse, RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -203,8 +205,23 @@ async def run_output_handler_loop(self): self._errored_with = exception if request_id is None: + for queue_i in tuple(self.output_queues.values()): - queue_i.put_nowait(exception) + + msg = str("A batch generation failed. Inspect the " + "stacktrace to find the original error: " + f"{repr(exception)}") + # If it is a runtime exception, we assume that + # the engine is already dead, let's pass this + # information ahead. Otherwise we just set as + # batch error, and maybe the engine is still + # up running. + # For runtime exceptions vLLM process will + # shutdown immediately. + batch_error = MQEngineDeadError(msg) if isinstance( + exception, + RuntimeError) else MQEngineBatchError(msg) + queue_i.put_nowait(batch_error) else: queue = self.output_queues.get(request_id) if queue is not None: diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 5dcf50bd1b0a1..2591cf2195e36 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -8,7 +8,7 @@ from vllm import envs from vllm.engine.async_llm_engine import AsyncEngineDeadError -from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.engine.multiprocessing import MQEngineBatchError, MQEngineDeadError from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -28,7 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): config = uvicorn.Config(app, **uvicorn_kwargs) server = uvicorn.Server(config) - _add_shutdown_handlers(app, server) + _add_exception_handles(app, server) loop = asyncio.get_running_loop() @@ -58,8 +58,9 @@ async def dummy_shutdown() -> None: return server.shutdown() -def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: - """Adds handlers for fatal errors that should crash the server""" +def _add_exception_handles(app: FastAPI, server: uvicorn.Server) -> None: + """Adds handlers for custom errors that may crash the server or + improve the readability of the stacktrace""" @app.exception_handler(RuntimeError) async def runtime_error_handler(request: Request, __): @@ -101,3 +102,12 @@ async def mq_engine_dead_handler(_, __): server.should_exit = True return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineBatchError) + async def mq_engine_batch_error_handler(_, err): + """Log the error and pass an internal server error. + This error might be propagated to all requests of + a batch that failed to generate""" + logger.error("%s", repr(err)) + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)