Skip to content

Commit

Permalink
[Bugfix][core] replace heartbeat with pid check (#9818)
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde authored Oct 30, 2024
1 parent 9ff4511 commit 3b3f1e7
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 62 deletions.
27 changes: 26 additions & 1 deletion tests/mq_llm_engine/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from vllm.utils import FlexibleArgumentParser

MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, enforce_eager=True)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"

Expand Down Expand Up @@ -266,3 +266,28 @@ async def test_mp_cuda_init():

async with build_async_engine_client(args):
pass


@pytest.mark.asyncio
async def test_engine_process_death(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:

client = await engine.make_client()
assert client.is_running

# kill the engine process
engine.proc.kill()

# Generate call should fail
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass

# And the health check should show the engine is dead
with pytest.raises(RuntimeError, match="Engine process .* died"):
await client.check_health()

client.close()
2 changes: 1 addition & 1 deletion tests/mq_llm_engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __exit__(self, exc_type, exc_value, traceback):

async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config)
client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)
while True:
try:
await client.setup()
Expand Down
29 changes: 19 additions & 10 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Optional, Union, cast, overload)

import cloudpickle
import psutil
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
Expand Down Expand Up @@ -77,7 +78,8 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy
"""

def __init__(self, ipc_path: str, engine_config: EngineConfig):
def __init__(self, ipc_path: str, engine_config: EngineConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None

Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
self._engine_process = psutil.Process(engine_pid)

@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
Expand All @@ -131,21 +134,22 @@ def get_data_socket(self) -> Iterator[Socket]:
socket.close(linger=0)

async def run_heartbeat_loop(self, timeout: int):
"""Background loop that continually listens to the RPCServer for
heartbeats.
"""Background loop that continually checks to ensure the engine process
is still alive.
"""
try:
while True:
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
# No heartbeat was received. Set error and exit the loop
# Check if the engine process is running:
if not self._engine_process.is_running() or (
self._engine_process.status() == psutil.STATUS_ZOMBIE):
# NB: is_running() returns True for zombies
self._set_errored(
TimeoutError("No heartbeat received "
"from MQLLMEngine"))
logger.debug("Shutting down MQLLMEngineClient check "
"health loop due to timeout")
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) "
"died."))
break

else:
if await self.heartbeat_socket.poll(timeout=timeout):
# Heartbeat received- check the message
await self._check_success(
error_message="Heartbeat failed.",
Expand All @@ -156,6 +160,11 @@ async def run_heartbeat_loop(self, timeout: int):
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")

except psutil.NoSuchProcess:
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) died."))

except Exception as e:
self._set_errored(e)

Expand Down
59 changes: 11 additions & 48 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import pickle
import signal
import threading
import time
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

Expand All @@ -21,7 +19,7 @@
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1
from vllm.envs import VLLM_USE_V1
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
Expand Down Expand Up @@ -108,20 +106,6 @@ def __init__(self,
# Error state.
self._errored_with: Optional[BaseException] = None

# Heartbeat thread
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
daemon=True)
self._heartbeat_stop_event = threading.Event()
# The heartbeat needs to be faster than what the client will wait for
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0

self._last_alive_time = time.time()
# The heartbeats can tolerate a long period of the engine chugging
# away at a generation request.
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0

@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
Expand Down Expand Up @@ -157,8 +141,6 @@ def start(self):
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting heartbeat thread")
self.heartbeat_thread.start()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
Expand All @@ -172,7 +154,6 @@ def start(self):
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self._heartbeat_stop_event.set()
self.ctx.destroy(linger=0)
del self.engine

Expand Down Expand Up @@ -211,11 +192,12 @@ def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

while True:
self._alive()
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self._alive()
# When there's no work, check on engine health and send
# health status back to client
self._health_check()
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")

Expand Down Expand Up @@ -314,32 +296,16 @@ def _handle_abort_request(self, request: RPCAbortRequest):
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)

def _heartbeat_loop(self):
while not self._heartbeat_stop_event.wait(
timeout=self.heartbeat_interval_seconds):
# Loops until the stop event is set
self._heartbeat()

logger.debug("Exiting MQLLMEngine heartbeat thread")

def _heartbeat(self):
def _health_check(self):
# Send unhealthy if engine has already errored
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)

# Check for life of the main loop
elif time.time() - self._last_alive_time > self.last_alive_threshold:
self._send_unhealthy(RuntimeError("Engine loop has died"))

else:
# Otherwise- check health of the engine
# self.engine.check_health() raises on unhealthy
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
try:
self.engine.check_health()
self._send_healthy()
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)

def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
Expand Down Expand Up @@ -369,9 +335,6 @@ def _set_errored(self, e: BaseException):
if self._errored_with is None:
self._errored_with = e

def _alive(self):
self._last_alive_time = time.time()

def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile()
Expand Down
7 changes: 5 additions & 2 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,16 @@ async def build_async_engine_client_from_engine_args(
UsageContext.OPENAI_API_SERVER,
ipc_path))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start"
logger.info("Started engine process with PID %d", engine_pid)

# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config,
engine_pid)

try:
while True:
Expand Down

0 comments on commit 3b3f1e7

Please sign in to comment.