api: fix crashes under very high loads (#878)
AlpinDale authored Dec 12, 2024
1 parent 9fd2bfa commit b5aa110
Showing 8 changed files with 677 additions and 138 deletions.
10 changes: 7 additions & 3 deletions aphrodite/endpoints/openai/api_server.py
@@ -145,6 +145,13 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
rpc_path = get_open_zmq_ipc_path()
logger.info(f"Multiprocessing frontend to use {rpc_path} for RPC Path."
)

# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
async_engine_client = rpc_client # type: ignore

# Start RPCServer in separate process (holds the AsyncAphrodite).
context = multiprocessing.get_context("spawn")
# the current process might have CUDA context,
@@ -156,9 +163,6 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
logger.info(
f"Started engine process with PID {rpc_server_process.pid}")

# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(rpc_path)

try:
while True:
try:
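The hunks above construct the RPC client (and with it the bind() on the ipc endpoint) before the engine process is spawned. A plausible reading is that the client must bind its endpoint before the engine connects to it. A minimal sketch of that bind-before-spawn ordering, assuming pyzmq; the names here are illustrative, not from the commit:

import multiprocessing

import zmq


def engine_proc(rpc_path: str) -> None:
    # The spawned process connects to the endpoint the parent already bound.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect(rpc_path)
    sock.send(b"ready")
    sock.close()
    ctx.term()


if __name__ == "__main__":
    rpc_path = "ipc:///tmp/demo_rpc.ipc"
    ctx = zmq.Context()
    client = ctx.socket(zmq.DEALER)
    client.bind(rpc_path)  # bind first, mirroring the reorder above
    proc = multiprocessing.get_context("spawn").Process(target=engine_proc,
                                                        args=(rpc_path,))
    proc.start()
    print(client.recv())  # b"ready"
    proc.join()
    client.close()
    ctx.term()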
11 changes: 9 additions & 2 deletions aphrodite/endpoints/openai/rpc/__init__.py
@@ -7,8 +7,15 @@
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest

# Success string used for RPC instructions.
APHRODITE_RPC_SUCCESS_STR = "SUCCESS"
APHRODITE_RPC_HEALTHY_STR = "HEALTHY"
# Timeouts.
APHRODITE_RPC_SERVER_START_TIMEOUT_MS = 1000
APHRODITE_RPC_HEALTH_TIMEOUT_MS = 10000
# Minimum value of ZMQ.SOCKET_LIMIT needed to run the multiprocessing frontend.
APHRODITE_RPC_SOCKET_LIMIT_CUTOFF = 2000
# An HWM of 0 sets the high-water mark to infinity (no messages dropped).
APHRODITE_RPC_ZMQ_HWM = 0


@dataclass
@@ -33,7 +40,7 @@ class RPCUtilityRequest(Enum):
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
CHECK_HEALTH = 8
IS_SERVER_HEALTHY = 8


RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
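For context, the two ZMQ knobs introduced above can be inspected directly from pyzmq. A minimal sketch, not part of the commit:

import zmq

ctx = zmq.Context()

# Hard cap on sockets per context, typically 65536; the client refuses to
# run the multiprocessing frontend when this is below the 2000 cutoff.
print(ctx.get(zmq.SOCKET_LIMIT))

sock = ctx.socket(zmq.DEALER)
sock.set_hwm(0)  # an HWM of 0 means an unbounded queue, so nothing is dropped
sock.close(linger=0)
ctx.term()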
230 changes: 163 additions & 67 deletions aphrodite/endpoints/openai/rpc/client.py
@@ -1,39 +1,130 @@
import asyncio
from contextlib import contextmanager
from typing import Any, AsyncGenerator, Optional
from uuid import uuid4

import cloudpickle
import zmq
import zmq.asyncio
from loguru import logger

from aphrodite.common.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.endpoints.openai.rpc import (APHRODITE_RPC_HEALTHY_STR,
APHRODITE_RPC_SUCCESS_STR,
RPC_REQUEST_TYPE, RPCAbortRequest,
RPCGenerateRequest,
RPCUtilityRequest)
from aphrodite.endpoints.openai.rpc import (
APHRODITE_RPC_HEALTH_TIMEOUT_MS, APHRODITE_RPC_SERVER_START_TIMEOUT_MS,
APHRODITE_RPC_SOCKET_LIMIT_CUTOFF, APHRODITE_RPC_SUCCESS_STR,
APHRODITE_RPC_ZMQ_HWM, RPC_REQUEST_TYPE, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from aphrodite.inputs import PromptInputs
from aphrodite.lora.request import LoRARequest
from aphrodite.prompt_adapter.request import PromptAdapterRequest
from aphrodite.transformers_utils.tokenizer_group import (
init_tokenizer_from_configs)

# Time to wait before checking if the server process is alive
SERVER_START_TIMEOUT_MS = 1000
# Path used for the in-process proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"

class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RCPGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""

def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
self.rpc_path = rpc_path
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
if socket_limit < APHRODITE_RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
"the number of concurrent requests Aphrodite can process. "
"Launch Aphrodite with --disable-frontend-multiprocessing and "
"open a GitHub issue so we can investigate.")
# We only have 1 ipc connection that uses unix sockets, so it is
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. we will
# not run into ulimit issues).
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(APHRODITE_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server = self.context.socket(zmq.constants.ROUTER)
self.from_api_server.set_hwm(APHRODITE_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy.
self.proxy_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in Aphrodite with frontend
# multiprocessing. This value is passed to uvicorn via
# --limit-concurrency so that it returns 503 when the server is
# overloaded (a usage sketch follows this file's diff). Each request
# needs 2 sockets (1 for generate(); 1 for abort(), do_log_stats(),
# or check_health()), and 2 more are held in reserve, hence
# socket_limit // 2 - 2.
self.limit_concurrency = socket_limit // 2 - 2

async def run_proxy(self, socket_from, socket_to):
"""Background task that runs a proxy"""
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events = await poller.poll()
events = dict(events)
if socket_from in events:
identity, msg = await socket_from.recv_multipart()
await socket_to.send_multipart([identity, msg])
if socket_to in events:
identity, msg = await socket_to.recv_multipart()
await socket_from.send_multipart([identity, msg])

async def setup(self):
"""Setup the client before it starts sending server requests."""

# Wait until server is ready.
await self.wait_for_server()
await self._wait_for_server_rpc()
self._errored = False

# Get the configs.
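As a gloss on run_proxy() above: a ROUTER socket prepends an identity frame to every message it receives, and routes a reply by that same frame. A standalone sketch of the mechanism, assuming pyzmq and not part of the commit:

import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://identity-demo")

dealer = ctx.socket(zmq.DEALER)
dealer.connect("inproc://identity-demo")
dealer.send(b"hello")

# The ROUTER sticks the peer identity in front of the payload...
identity, msg = router.recv_multipart()
# ...and sending that identity frame first routes the reply back.
router.send_multipart([identity, b"world"])
print(dealer.recv())  # b"world"

dealer.close()
router.close()
ctx.term()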
@@ -51,40 +142,34 @@ async def setup(self):

def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self.from_api_server.close()
self.to_rpc_server.close()
self.context.destroy()

@contextmanager
def socket(self):
# Ensure client sockets are always closed after use

# Connect to RPC socket for Request-Reply pattern,
@contextmanager
def to_proxy_socket(self):
# Connect to the RPCServer via the proxy.
# Note that we use a DEALER socket to enable asynchronous,
# streaming communication.
socket = self.context.socket(zmq.constants.DEALER)
socket.set_hwm(APHRODITE_RPC_ZMQ_HWM)
try:
socket.connect(self.rpc_path)
socket.connect(INPROC_PROXY_PATH)
yield socket
finally:
# linger == 0 means discard unsent messages
# when the socket is closed. This is necessary
# because otherwise self.context.destroy() will
# wait for 30 seconds until unsent messages are
# received, which is impossible if the server
# crashed. In the absence of a server crash we
# always expect a response before closing the
# socket anyway.
# Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
socket.close(linger=0)
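# (A standalone sketch of this linger behavior follows this hunk.)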

async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
error_message: str) -> Any:
"""Send an RPC request that is expecting data back."""

with self.socket() as socket:

with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send(cloudpickle.dumps(request))
await socket.send_multipart([cloudpickle.dumps(request)])

# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
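The linger note in to_proxy_socket() above can be seen in isolation: with a nonzero linger, context teardown blocks on undeliverable messages, while linger=0 discards them. A sketch, assuming pyzmq and not part of the commit:

import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
# ipc connects are lazy, so this succeeds with no server listening.
sock.connect("ipc:///tmp/nobody-is-listening.ipc")
sock.send(b"never delivered")  # queued locally, waiting for a peer

sock.close(linger=0)  # linger=0 drops the queued message immediately
ctx.term()            # returns at once; nothing is left lingering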
@@ -93,27 +178,45 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
elif isinstance(data, Exception):
logger.error(error_message)
raise data
else:
raise ValueError(error_message)

return data

async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
error_message: str,
timeout: Optional[int] = None,
socket: Optional[zmq.asyncio.Socket] = None):
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))

# Await acknowledgement from RPCServer.
response = cloudpickle.loads(await socket.recv())

if not isinstance(response, str) or response != \
APHRODITE_RPC_SUCCESS_STR:
async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE,
timeout=None):
await socket.send_multipart([cloudpickle.dumps(request)])
if timeout is not None and await socket.poll(timeout=timeout) == 0:
raise TimeoutError(f"Server didn't reply within {timeout} ms")
return cloudpickle.loads(await socket.recv())

# Make a new socket connection.
if socket is None:
with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request, timeout)
# Use existing socket connection.
else:
response = await do_rpc_call(socket, request, timeout)

if not isinstance(
response, str) or response != APHRODITE_RPC_SUCCESS_STR:
if isinstance(response, Exception):
logger.error(error_message)
raise response
raise ValueError(error_message)

return response

async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

Expand All @@ -123,12 +226,13 @@ async def get_decoding_config(self) -> DecodingConfig:
async def get_model_config(self) -> ModelConfig:
return self.model_config

async def wait_for_server(self):
async def _wait_for_server_rpc(self):
"""Wait for the RPCServer to start up."""

await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")
error_message="Unable to start RPC Server",
timeout=APHRODITE_RPC_SERVER_START_TIMEOUT_MS)

async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
@@ -208,7 +312,7 @@ async def generate(

finished = False
try:
with self.socket() as socket:
with self.to_proxy_socket() as socket:

# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
@@ -227,43 +331,35 @@ async def generate(
request_output = cloudpickle.loads(message)

if isinstance(request_output, Exception):
# On exception, check if the server is still healthy.
# Use this to set the sync `is_running` and `errored`
# properties.
try:
await self.check_health()
except Exception:
self._errored = True
# On exception, check if the server is still healthy,
# possibly setting the `errored` property.
if not self._errored:
try:
await self.check_health(socket=socket)
except Exception as e:
self._errored = True
logger.exception(repr(e))
# NB: do this before raising so that the flag is set
# by the time the caller receives this exception.
raise request_output

finished = request_output.finished
yield request_output
finally:
if not finished:
# Request was canceled by the client.
if not finished and not self._errored:
await self.abort(request_id)

async def check_health(self) -> None:
async def check_health(self,
socket: Optional[zmq.asyncio.Socket] = None
) -> None:
"""Raise if unhealthy"""

with self.socket() as socket:

# Ping RPCServer with CHECK_HEALTH request.
await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH))

# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message = cloudpickle.loads(await socket.recv())

if isinstance(health_message, Exception):
raise health_message

if health_message != APHRODITE_RPC_HEALTHY_STR:
raise ValueError("Expected healthy response from backend but got "
f"{health_message}")
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server",
timeout=APHRODITE_RPC_HEALTH_TIMEOUT_MS,
socket=socket)

async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
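Taken together, a sketch of how the new knobs surface one layer up, assuming uvicorn; serve() and probe() are hypothetical helpers, not part of the commit:

import uvicorn

from aphrodite.endpoints.openai.rpc.client import AsyncEngineRPCClient


async def probe(client: AsyncEngineRPCClient) -> None:
    # Opens a fresh proxy socket and raises if the server is unhealthy;
    # generate() instead reuses its in-flight socket via the socket= kwarg.
    await client.check_health()


def serve(app, client: AsyncEngineRPCClient) -> None:
    # uvicorn returns 503 once concurrency exceeds the socket-derived cap.
    uvicorn.run(app, limit_concurrency=client.limit_concurrency)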