Skip to content

Commit

Permalink
[Frontend] Don't log duplicate error stacktrace for every request in …
Browse files Browse the repository at this point in the history
…the batch (#9023)

Signed-off-by: Wallas Santos <[email protected]>
  • Loading branch information
wallashss authored Oct 21, 2024
1 parent 15713e3 commit 711f3a7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
51 changes: 41 additions & 10 deletions tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,7 @@ async def test_evil_forward(tmp_socket):
await asyncio.sleep(2.0)
await client.check_health()

# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored

# Engine is errored, should get ENGINE_DEAD_ERROR.
# Throws an error that should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
Expand Down Expand Up @@ -149,7 +141,7 @@ async def test_failed_abort(tmp_socket):
client = await engine.make_client()
assert client.is_running

# Firsh check health should work.
# First check health should work.
await client.check_health()

# Trigger an abort on the client side.
Expand All @@ -174,6 +166,45 @@ async def test_failed_abort(tmp_socket):
client.close()


@pytest.mark.asyncio
async def test_batch_error(tmp_socket):
    """A batch of in-flight requests should all fail with the same
    MQEngineDeadError once the engine dies mid-batch."""
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_abort) as engine:

        client = await engine.make_client()
        assert client.is_running

        # The first health check should succeed.
        await client.check_health()

        # Worker coroutine for a single request of the batch.
        # min_tokens=2048 keeps the engine busy long enough for the
        # crashing request (the evil abort below) to be processed
        # while this batch is still in flight.
        async def submit_request(mq_client):
            sampling = SamplingParams(min_tokens=2048, max_tokens=2048)
            generator = mq_client.generate(prompt="Hello my name is",
                                           sampling_params=sampling,
                                           request_id=uuid.uuid4())
            async for _ in generator:
                pass

        # Launch a batch of concurrent requests.
        pending = [
            asyncio.create_task(submit_request(client)) for _ in range(10)
        ]

        # This abort forces the currently-processing batch to raise an
        # exception, after which the engine becomes errored.
        await client.abort(request_id="foo")

        # Every request in the failed batch should then surface the
        # same exception, as a MQEngineDeadError wrapping the KeyError.
        results = await asyncio.gather(*pending, return_exceptions=True)
        for err in results:
            assert isinstance(err, MQEngineDeadError)
            assert "KeyError" in repr(err)

        client.close()


@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
Expand Down
12 changes: 12 additions & 0 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,20 @@ async def run_output_handler_loop(self):
# (and record only the first one)
if is_engine_errored and not self._errored_with:
self._errored_with = exception
# If engine is errored, no matter the type of exception
# it will no longer be able to receive new requests,
# therefore we have to inform that the current
# processed requests failed as well. Send back a dead
# engine error give this feedback and also give a
# 'hint' to the server to shutdown next.
exception = self.dead_error

if request_id is None:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, neither if it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception for all requests.
for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception)
else:
Expand Down

0 comments on commit 711f3a7

Please sign in to comment.