api: add client timeouts for the ZeroMQ server (#897)
AlpinDale authored Dec 16, 2024
1 parent 908ff75 commit a00ab49
Showing 6 changed files with 181 additions and 30 deletions.
6 changes: 6 additions & 0 deletions aphrodite/common/envs.py
@@ -54,6 +54,7 @@
APHRODITE_TEST_FORCE_FP8_MARLIN: bool = False
APHRODITE_ALLOW_ENGINE_USE_RAY: bool = False
APHRODITE_PLUGINS: Optional[List[str]] = None
APHRODITE_RPC_GET_DATA_TIMEOUT_MS: int = 5000


def get_default_cache_root():
@@ -362,6 +363,11 @@ def get_default_config_root():
(os.environ.get("APHRODITE_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
("1", "true")),

# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"APHRODITE_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("APHRODITE_RPC_GET_DATA_TIMEOUT_MS", "5000")),

# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
"APHRODITE_ALLOW_ENGINE_USE_RAY":
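
The new environment variable is read lazily with a 5000 ms default, so deployments can raise the RPC data timeout without code changes. A minimal sketch of that lookup, with the override value and surrounding code illustrative rather than taken from the repository:

import os

# Simulate a deployment override before the envs module is first consulted.
os.environ["APHRODITE_RPC_GET_DATA_TIMEOUT_MS"] = "30000"

# Mirrors the lambda registered in aphrodite/common/envs.py above.
timeout_ms = int(os.getenv("APHRODITE_RPC_GET_DATA_TIMEOUT_MS", "5000"))
print(timeout_ms)  # -> 30000
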
5 changes: 3 additions & 2 deletions aphrodite/endpoints/openai/api_server.py
@@ -7,7 +7,7 @@
import re
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from distutils.util import strtobool
from http import HTTPStatus
from typing import AsyncGenerator, AsyncIterator, List, Optional, Set, Tuple
@@ -98,7 +98,8 @@ async def lifespan(app: FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
with suppress(Exception):
await async_engine_client.do_log_stats()

if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
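
The api_server change wraps the periodic stats call in contextlib.suppress so that a client which has already shut down (or any other transient failure) no longer kills the background logging task. A self-contained sketch of the pattern, with the engine_client and interval names illustrative:

import asyncio
from contextlib import suppress

async def _force_log(engine_client, interval: float = 10.0):
    # Periodically ask the engine client to log stats; swallow errors so an
    # RPC client shutdown cannot crash this background task.
    while True:
        await asyncio.sleep(interval)
        with suppress(Exception):
            await engine_client.do_log_stats()
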
3 changes: 0 additions & 3 deletions aphrodite/endpoints/openai/rpc/__init__.py
@@ -9,9 +9,6 @@

# Success string used for RPC instructions.
APHRODITE_RPC_SUCCESS_STR = "SUCCESS"
# Timeouts.
APHRODITE_RPC_SERVER_START_TIMEOUT_MS = 1000
APHRODITE_RPC_HEALTH_TIMEOUT_MS = 10000
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
APHRODITE_RPC_SOCKET_LIMIT_CUTOFF = 2000
# HWM is set to Infinity.
82 changes: 57 additions & 25 deletions aphrodite/endpoints/openai/rpc/client.py
@@ -1,5 +1,5 @@
import asyncio
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Optional
from uuid import uuid4

@@ -10,13 +10,15 @@

from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from aphrodite.common.envs import APHRODITE_RPC_GET_DATA_TIMEOUT_MS
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.endpoints.openai.rpc import (
APHRODITE_RPC_HEALTH_TIMEOUT_MS, APHRODITE_RPC_SERVER_START_TIMEOUT_MS,
APHRODITE_RPC_SOCKET_LIMIT_CUTOFF, APHRODITE_RPC_SUCCESS_STR,
APHRODITE_RPC_ZMQ_HWM, RPC_REQUEST_TYPE, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_SOCKET_LIMIT_CUTOFF,
APHRODITE_RPC_SUCCESS_STR,
APHRODITE_RPC_ZMQ_HWM,
RPC_REQUEST_TYPE, RPCAbortRequest,
RPCGenerateRequest,
RPCUtilityRequest)
from aphrodite.inputs import PromptInputs
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -26,6 +28,18 @@
# Path used for inprocess proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"


class RPCClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So we raise this error instead, allowing callers to suppress it cleanly.
"""


class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -73,6 +87,8 @@ class AsyncEngineRPCClient:

def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
self._data_timeout = APHRODITE_RPC_GET_DATA_TIMEOUT_MS
self._errored = False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
@@ -125,7 +141,6 @@ async def setup(self):

# Wait until server is ready.
await self._wait_for_server_rpc()
self._errored = False

# Get the configs.
self.model_config = await self._get_model_config_rpc()
@@ -152,6 +167,13 @@ def close(self):
@contextmanager
def to_proxy_socket(self):
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if self.context.closed:
raise RPCClientClosedError("The ZMQ client has already shut down")
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
@@ -171,9 +193,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
# Ping RPCServer with a request.
await socket.send_multipart([cloudpickle.dumps(request)])

# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())

if isinstance(data, Exception):
# Re-raise exceptions returned by the server
raise data

if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
@@ -190,25 +221,28 @@ async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
timeout: Optional[int] = None,
socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action."""

async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE,
timeout=None):
request: RPC_REQUEST_TYPE):

await socket.send_multipart([cloudpickle.dumps(request)])
if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"Server didn't reply within {timeout} ms")

if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")

return cloudpickle.loads(await socket.recv())

# Make a new socket connection.
if socket is None:
with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request, timeout)
response = await do_rpc_call(socket, request)

# Use existing socket connection.
else:
response = await do_rpc_call(socket, request, timeout)
response = await do_rpc_call(socket, request)

if not isinstance(
response, str) or response != APHRODITE_RPC_SUCCESS_STR:
@@ -231,8 +265,7 @@ async def _wait_for_server_rpc(self):

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server",
timeout=APHRODITE_RPC_SERVER_START_TIMEOUT_MS)
error_message="Unable to start RPC Server")

async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
@@ -276,17 +309,17 @@ async def _get_lora_config_rpc(self) -> LoRAConfig:

async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""

await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")

async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server"""

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")

@property
def is_running(self) -> bool:
@@ -358,7 +391,6 @@ async def check_health(self,
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server",
timeout=APHRODITE_RPC_HEALTH_TIMEOUT_MS,
socket=socket)

async def encode(self, *args,
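
The client changes combine three mechanisms: refusing to open sockets on a closed ZMQ context (RPCClientClosedError), polling with the configured data timeout before receiving, and re-raising exceptions shipped back by the server. A simplified, standalone sketch of that flow, assuming a DEALER socket and cloudpickle framing as in the diff (the helper name and the omission of the inproc proxy are illustrative):

import cloudpickle
import zmq
import zmq.asyncio


class RPCClientClosedError(Exception):
    """Raised when the ZMQ context has already been closed."""


async def send_and_wait(context: zmq.asyncio.Context, path: str,
                        request, timeout_ms: int):
    # Fail fast with a suppressible error if the client was already closed.
    if context.closed:
        raise RPCClientClosedError("The ZMQ client has already shut down")
    socket = context.socket(zmq.constants.DEALER)
    try:
        socket.connect(path)
        await socket.send_multipart([cloudpickle.dumps(request)])
        # Poll before recv() so a dead or stalled server surfaces as an error
        # instead of hanging the coroutine forever.
        if await socket.poll(timeout=timeout_ms) == 0:
            raise TimeoutError(f"Server didn't reply within {timeout_ms} ms")
        reply = cloudpickle.loads(await socket.recv())
        if isinstance(reply, Exception):
            # The server returns exceptions as payloads; re-raise locally.
            raise reply
        return reply
    finally:
        socket.close()
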
Empty file.
115 changes: 115 additions & 0 deletions tests/endpoints/openai/rpc/test_zmq_client.py
@@ -0,0 +1,115 @@
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid

import pytest
import pytest_asyncio

from aphrodite.endpoints.openai.rpc.client import (AsyncEngineRPCClient,
RPCClientClosedError)
from aphrodite.endpoints.openai.rpc.server import AsyncEngineRPCServer
from aphrodite.engine.async_aphrodite import AsyncAphrodite


@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"


@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
dummy_engine = unittest.mock.AsyncMock()

def dummy_engine_builder(*args, **kwargs):
return dummy_engine

with monkeypatch.context() as m:
m.setattr(AsyncAphrodite, "from_engine_args", dummy_engine_builder)
server = AsyncEngineRPCServer(None, rpc_path=tmp_socket)
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
try:
yield server
finally:
server_task.cancel()
server.cleanup()


@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
client = AsyncEngineRPCClient(rpc_path=tmp_socket)
# Sanity check: the server is connected
await client._wait_for_server_rpc()
try:
yield client
finally:
client.close()


@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(
monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
with monkeypatch.context() as m:
# Make the server _not_ reply with a model config
m.setattr(dummy_server, "get_config", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# And ensure the task completes anyway
# (client.setup() invokes server.get_config())
client_task = asyncio.get_running_loop().create_task(client.setup())
with pytest.raises(TimeoutError, match="Server didn't reply within"):
await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(
monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
with monkeypatch.context() as m:
# Hang all abort requests
m.setattr(dummy_server, "abort", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# Ensure the client doesn't hang
client_task = asyncio.get_running_loop().create_task(
client.abort("test request id")
)
with pytest.raises(TimeoutError, match="Server didn't reply within"):
await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
with monkeypatch.context() as m:
# Make the server raise some random exception
exception = RuntimeError("Client test exception")

def raiser():
raise exception

m.setattr(dummy_server.engine, "get_model_config", raiser)
m.setattr(client, "_data_timeout", 10)
client_task = asyncio.get_running_loop().create_task(client.setup())
# And ensure the task completes, raising the exception
with pytest.raises(RuntimeError, match=str(exception)):
await asyncio.wait_for(client_task, timeout=0.05)


@pytest.mark.asyncio
async def test_client_errors_after_closing(
monkeypatch, dummy_server, client: AsyncEngineRPCClient
):
client.close()
# Healthchecks and generate requests will fail with explicit errors
with pytest.raises(RPCClientClosedError):
await client.check_health()
with pytest.raises(RPCClientClosedError):
async for _ in client.generate(None, None, None):
pass
# But no-ops like aborting will pass
await client.abort("test-request-id")
await client.do_log_stats()
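
The added tests exercise these paths end-to-end against a mocked engine. They can be run on their own with pytest, assuming pytest-asyncio is installed (the file relies on pytest_asyncio fixtures); one hedged way to invoke just this file from Python:

import pytest

# Run only the new ZMQ client tests.
pytest.main(["-q", "tests/endpoints/openai/rpc/test_zmq_client.py"])
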
