From fac77b0a0d388446866253eb8e548531426a47b0 Mon Sep 17 00:00:00 2001 From: Wallas Henrique Date: Mon, 21 Oct 2024 18:49:41 -0300 Subject: [PATCH] [Frontend] Don't log duplicate error stacktrace for every request in the batch (#9023) Signed-off-by: Wallas Santos Signed-off-by: Amit Garg --- tests/mq_llm_engine/test_error_handling.py | 51 +++++++++++++++++----- vllm/engine/multiprocessing/client.py | 12 +++++ 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 616a15a1328de..205ab00aa6b17 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -59,15 +59,7 @@ async def test_evil_forward(tmp_socket): await asyncio.sleep(2.0) await client.check_health() - # Throws an error in first forward pass. - with pytest.raises(RAISED_ERROR): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=uuid.uuid4()): - pass - assert client.errored - - # Engine is errored, should get ENGINE_DEAD_ERROR. + # Throws an error that should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), @@ -149,7 +141,7 @@ async def test_failed_abort(tmp_socket): client = await engine.make_client() assert client.is_running - # Firsh check health should work. + # First check health should work. await client.check_health() # Trigger an abort on the client side. @@ -174,6 +166,45 @@ async def test_failed_abort(tmp_socket): client.close() +@pytest.mark.asyncio +async def test_batch_error(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # First check health should work. 
+ await client.check_health() + + # Batch of requests + async def do_generate(client): + # min_tokens=2048 to keep the engine busy + # to get enough time to process a request + # that will crash the engine + params = SamplingParams(min_tokens=2048, max_tokens=2048) + async for _ in client.generate(prompt="Hello my name is", + sampling_params=params, + request_id=uuid.uuid4()): + pass + + tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] + + # This request will force a processing batch to raise + # an exception and then the engine gets errored + await client.abort(request_id="foo") + + # The batch of those requests failed, then they + # should get the same exception as a MQEngineDeadError. + errors = await asyncio.gather(*tasks, return_exceptions=True) + for e in errors: + assert isinstance(e, MQEngineDeadError) + assert "KeyError" in repr(e) + + client.close() + + @pytest.mark.asyncio async def test_bad_request(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, diff --git a/vllm/engine/multiprocessing/client.py index 9732c7098e160..9e5a6b21f4c18 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -204,8 +204,20 @@ async def run_output_handler_loop(self): # (and record only the first one) if is_engine_errored and not self._errored_with: self._errored_with = exception + # If engine is errored, no matter the type of exception + # it will no longer be able to receive new requests, + # therefore we have to inform that the current + # processed requests failed as well. Send back a dead + # engine error to give this feedback and also give a + # 'hint' to the server to shut down next. + exception = self.dead_error if request_id is None: + # If request_id is None, then the engine raised an + # exception for a batch, and we may not know the + # request that caused it, nor whether it was actually + # caused by any of them (e.g. CUDA OOM). 
Therefore we + # broadcast the same exception for all requests. for queue_i in tuple(self.output_queues.values()): queue_i.put_nowait(exception) else: