[Frontend] Kill the server on engine death (#6594)
Signed-off-by: Joe Runde <[email protected]>
joerunde authored Aug 8, 2024
1 parent 5fb4a3f commit 21b9c49
Showing 8 changed files with 136 additions and 14 deletions.
47 changes: 47 additions & 0 deletions tests/entrypoints/openai/test_shutdown.py
@@ -0,0 +1,47 @@
import json
import os

import openai
import pytest

from ...utils import RemoteOpenAIServer

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.mark.asyncio
async def test_shutdown_on_engine_failure(tmp_path):
# Use a bad adapter to crash the engine
# (This test will fail when that bug is fixed)
adapter_path = tmp_path / "bad_adapter"
os.mkdir(adapter_path)
with open(adapter_path / "adapter_model_config.json", "w") as f:
json.dump({"not": "real"}, f)
with open(adapter_path / "adapter_model.safetensors", "wb") as f:
f.write(b"this is fake")

# dtype, max-len etc set so that this can run in CI
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
"--max-num-seqs",
"128",
"--enable-lora",
"--lora-modules",
f"bad-adapter={tmp_path / 'bad_adapter'}",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
client = remote_server.get_async_client()

with pytest.raises(openai.APIConnectionError):
# This crashes the engine
await client.completions.create(model="bad-adapter",
prompt="Hello, my name is")

# Now the server should shut down
return_code = remote_server.proc.wait(timeout=1)
assert return_code is not None
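
For reference, a minimal sketch of the wait-for-exit check the final assertion relies on, assuming RemoteOpenAIServer.proc is a plain subprocess.Popen as the call above suggests: wait(timeout=...) raises subprocess.TimeoutExpired while the process is still alive and returns the exit code once it has terminated.

    import subprocess

    # Hypothetical stand-in for the server process; on a POSIX system this
    # command exits on its own after 0.1 seconds.
    proc = subprocess.Popen(["sleep", "0.1"])

    # Raises subprocess.TimeoutExpired if the process is still running after
    # the timeout; otherwise returns the exit code, which is never None.
    return_code = proc.wait(timeout=5)
    assert return_code is not None
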
6 changes: 4 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -58,7 +58,7 @@ def _log_task_completion(task: asyncio.Task,
error_callback(exception)
raise AsyncEngineDeadError(
"Task finished unexpectedly. This should never happen! "
"Please open an issue on Github. See stack trace above for the"
"Please open an issue on Github. See stack trace above for the "
"actual cause.") from e


@@ -132,7 +132,9 @@ def propagate_exception(self,
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else:
-            for rid, stream in self._request_streams.items():
+            # NB: list() used here because self.abort_request pops the stream
+            # out of self._request_streams, so we can't iterate on it directly
+            for rid, stream in list(self._request_streams.items()):
stream.put(exc)
self.abort_request(rid)

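
As a standalone illustration (not part of the diff), this is the failure mode the list() snapshot above avoids: removing entries from a dict while iterating over it raises RuntimeError in CPython.

    streams = {"req-1": object(), "req-2": object()}

    try:
        for request_id in streams:
            streams.pop(request_id)  # like abort_request() popping the stream
    except RuntimeError as exc:
        print(exc)  # "dictionary changed size during iteration"

    # Iterating over a snapshot of the items makes removal during the loop safe.
    streams = {"req-1": object(), "req-2": object()}
    for request_id, stream in list(streams.items()):
        streams.pop(request_id)
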
1 change: 1 addition & 0 deletions vllm/entrypoints/api_server.py
@@ -118,6 +118,7 @@ async def run_server(args: Namespace,

shutdown_task = await serve_http(
app,
engine=engine,
host=args.host,
port=args.port,
log_level=args.log_level,
44 changes: 42 additions & 2 deletions vllm/entrypoints/launcher.py
@@ -1,16 +1,21 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any

import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


-async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
+async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+                     **uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
@@ -23,6 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):

config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine)

loop = asyncio.get_running_loop()

@@ -44,3 +50,37 @@ async def dummy_shutdown() -> None:
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
return server.shutdown()


def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
engine: AsyncEngineClient) -> None:
"""Adds handlers for fatal errors that should crash the server"""

@app.exception_handler(RuntimeError)
async def runtime_error_handler(_, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True

return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

@app.exception_handler(AsyncEngineDeadError)
async def engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True

return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
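
The handlers above rely on uvicorn's cooperative shutdown flag rather than killing the process outright. Below is a self-contained sketch of that pattern; the toy route, host, and port are assumptions for illustration, not vLLM code.

    import asyncio
    from http import HTTPStatus

    import uvicorn
    from fastapi import FastAPI, Response

    app = FastAPI()
    config = uvicorn.Config(app, host="127.0.0.1", port=8000)
    server = uvicorn.Server(config)


    @app.get("/crash")
    async def crash() -> None:
        # Stand-in for a request that hits a dead engine.
        raise RuntimeError("simulated engine death")


    @app.exception_handler(RuntimeError)
    async def runtime_error_handler(_, __) -> Response:
        # Respond first; uvicorn notices should_exit on its next tick and
        # begins a graceful shutdown, draining in-flight connections.
        server.should_exit = True
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)


    if __name__ == "__main__":
        # GET /crash returns 500 to the caller, then the server exits cleanly.
        asyncio.run(server.serve())
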
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -357,6 +357,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:

shutdown_task = await serve_http(
app,
engine=async_engine_client,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
26 changes: 24 additions & 2 deletions vllm/entrypoints/openai/rpc/client.py
@@ -33,6 +33,7 @@ async def setup(self):

# Wait until server is ready.
await self.wait_for_server()
self._errored = False

# Get the configs.
self.model_config = await self._get_model_config_rpc()
@@ -169,15 +170,15 @@ async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")

-    async def _get_lora_config_rpc(self):
+    async def _get_lora_config_rpc(self) -> LoRAConfig:
"""Get LoRAConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")

-    async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
+    async def _is_tracing_enabled_rpc(self) -> bool:
"""Get is_tracing_enabled flag from the RPCServer"""

return await self._send_get_data_rpc_request(
@@ -200,6 +201,18 @@ async def do_log_stats(self):
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")

@property
def is_running(self) -> bool:
return not self._errored

@property
def is_stopped(self) -> bool:
return self._errored

@property
def errored(self) -> bool:
return self._errored

async def generate(
self,
inputs: PromptInputs,
@@ -233,6 +246,15 @@ async def generate(
request_output = cloudpickle.loads(message)

if isinstance(request_output, Exception):
# On exception, check if the server is still healthy.
# Use this to set the sync `is_running` and `errored`
# properties.
try:
await self.check_health()
except Exception:
self._errored = True
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise request_output

finished = request_output.finished
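
A toy sketch of the ordering guarantee the NB comment above describes (illustrative class and method names, not vLLM's): the health probe runs and the errored flag is set before the exception ever reaches the caller.

    import asyncio


    class ToyClient:
        """Illustrative stand-in for the RPC client's error bookkeeping."""

        def __init__(self) -> None:
            self._errored = False

        @property
        def errored(self) -> bool:
            return self._errored

        async def check_health(self) -> None:
            # Pretend the RPC server has gone away.
            raise ConnectionError("RPC server is unreachable")

        async def handle(self, message: object) -> object:
            if isinstance(message, Exception):
                try:
                    await self.check_health()
                except Exception:
                    self._errored = True
                # The flag is already set by the time the caller sees this.
                raise message
            return message


    async def main() -> None:
        client = ToyClient()
        try:
            await client.handle(RuntimeError("engine died"))
        except RuntimeError:
            assert client.errored


    asyncio.run(main())
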
19 changes: 11 additions & 8 deletions vllm/entrypoints/openai/rpc/server.py
@@ -96,14 +96,17 @@ async def is_server_ready(self, identity):

async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
-        # Abort the request in the llm engine.
-        await self.engine.abort(request.request_id)
-
-        # Send confirmation to the client.
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+        try:
+            # Abort the request in the llm engine.
+            await self.engine.abort(request.request_id)
+        except Exception:
+            logger.warning("Failed to abort request %s", request.request_id)
+        finally:
+            # Send confirmation to the client.
+            await self.socket.send_multipart([
+                identity,
+                cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+            ])

async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
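
A minimal sketch of the ack-in-finally pattern, assuming (as the original code suggests) that the client side blocks waiting for this confirmation: even a failed abort gets acknowledged, so the waiting coroutine never stalls. The queue and function names here are illustrative.

    import asyncio


    async def handle_abort(replies: asyncio.Queue, abort) -> None:
        try:
            await abort()
        except Exception:
            print("abort failed, acknowledging anyway")
        finally:
            # The confirmation is sent whether or not the abort succeeded.
            await replies.put("ACK")


    async def main() -> None:
        replies: asyncio.Queue = asyncio.Queue()

        async def failing_abort() -> None:
            raise ValueError("unknown request id")

        await handle_abort(replies, failing_abort)
        # The client-side wait still completes despite the failure.
        assert await replies.get() == "ACK"


    asyncio.run(main())
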
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -49,6 +49,7 @@
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
@@ -335,6 +336,11 @@ def get_default_config_root():
"VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),

# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),

# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
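
A standalone sketch of how this lambda evaluates: because it applies bool() directly to the string (unlike the int-based flags such as VLLM_NO_DEPRECATION_WARNING above), any non-empty value, "0" included, switches keep-alive on.

    import os


    def parse_keep_alive() -> bool:
        # Same expression as the lambda in the diff above.
        return bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0))


    os.environ.pop("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", None)
    assert parse_keep_alive() is False   # unset -> bool(0) -> False

    os.environ["VLLM_KEEP_ALIVE_ON_ENGINE_DEATH"] = "0"
    assert parse_keep_alive() is True    # any non-empty string is truthy
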
