From d308dfce2e83708a5f37e1230bc8d52a6b235263 Mon Sep 17 00:00:00 2001 From: cblmemo Date: Sun, 5 May 2024 18:48:34 -0700 Subject: [PATCH] apply suggestions from code review --- sky/serve/constants.py | 3 +- sky/serve/load_balancer.py | 109 +++++++++++++-------------- sky/serve/load_balancing_policies.py | 5 +- sky/utils/common_utils.py | 5 +- 4 files changed, 62 insertions(+), 60 deletions(-) diff --git a/sky/serve/constants.py b/sky/serve/constants.py index ece2838be62..2ac8f9169ba 100644 --- a/sky/serve/constants.py +++ b/sky/serve/constants.py @@ -28,7 +28,8 @@ # The timeout in seconds for load balancer to wait for a response from replica. # Large LLMs like Llama2-70b is able to process the request within ~30 seconds. -# We set the timeout to 120s to be safe. +# We set the timeout to 120s to be safe. For reference, FastChat uses 100s: +# https://github.com/lm-sys/FastChat/blob/f2e6ca964af7ad0585cadcf16ab98e57297e2133/fastchat/constants.py#L39 # pylint: disable=line-too-long # TODO(tian): Expose this option to users in yaml file. LB_STREAM_TIMEOUT = 120 diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index a0818014808..ba5747301e3 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -3,11 +3,12 @@ import logging import threading import time -from typing import Optional +from typing import Dict, Union import fastapi import httpx import requests +from starlette import background import uvicorn from sky import sky_logging @@ -41,6 +42,7 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None: lb_policies.RoundRobinPolicy()) self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) + self._client_pool: Dict[str, httpx.AsyncClient] = dict() def _sync_with_controller(self): """Sync with controller periodically. @@ -71,71 +73,55 @@ def _sync_with_controller(self): ready_replica_urls = response.json().get( 'ready_replica_urls') except requests.RequestException as e: - print(f'An error occurred: {e}') + logger.error(f'An error occurred: {e}') else: logger.info(f'Available Replica URLs: {ready_replica_urls}') self._load_balancing_policy.set_ready_replicas( ready_replica_urls) + for replica_url in ready_replica_urls: + if replica_url not in self._client_pool: + # TODO(tian): Support HTTPS. + self._client_pool[replica_url] = httpx.AsyncClient( + base_url=f'http://{replica_url}') + closed_urls = [] + for replica_url, client in self._client_pool.items(): + if replica_url not in ready_replica_urls: + asyncio.run(client.aclose()) + closed_urls.append(replica_url) + for replica_url in closed_urls: + del self._client_pool[replica_url] time.sleep(constants.LB_CONTROLLER_SYNC_INTERVAL_SECONDS) async def _proxy_request_to( - self, url: str, - request: fastapi.Request) -> Optional[fastapi.responses.Response]: + self, url: str, request: fastapi.Request + ) -> Union[fastapi.responses.Response, Exception]: """Proxy the request to the specified URL. Returns: - The response from the endpoint replica. None if anything goes wrong. + The response from the endpoint replica. Return the exception + encountered if anything goes wrong. """ - method = request.method - headers = {key: value for key, value in request.headers.items()} - body = await request.body() - path = f'http://{url}{request.url.path}' - logger.info(f'Proxy request to {path}') + logger.info(f'Proxy request to {url}') try: - - async def stream_response(): - """Construct the response stream. - - Yields: - The response status code and headers first. Then the - response body. - """ - async with httpx.AsyncClient() as client: - async with client.stream( - method, - path, - headers=headers, - content=body, - timeout=constants.LB_STREAM_TIMEOUT) as response: - response.raise_for_status() - # Hacky. We need to construct the async client within - # the async generator to avoid the client being closed - # before the response is consumed. However, we still - # need the response status code and headers to construct - # the StreamingResponse, which is only available after - # the client is constructed. We yield them first here. - # TODO(tian): Investigate a way to not directly yielding - # the response status code and headers. - yield response.status_code - yield dict(response.headers) - try: - async for chunk in response.aiter_bytes(): - yield chunk - except Exception as e: # pylint: disable=broad-except - yield f'Error: {str(e)}' - finally: - await response.aclose() - - content = stream_response() - status_code = await content.__anext__() - headers = await content.__anext__() - return fastapi.responses.StreamingResponse(content=content, - status_code=status_code, - headers=headers) + client = self._client_pool[url] + worker_url = httpx.URL(path=request.url.path, + query=request.url.query.encode('utf-8')) + proxy_request = client.build_request( + request.method, + worker_url, + headers=request.headers.raw, + content=await request.body(), + timeout=constants.LB_STREAM_TIMEOUT) + proxy_response = await client.send(proxy_request, stream=True) + return fastapi.responses.StreamingResponse( + content=proxy_response.aiter_raw(), + status_code=proxy_response.status_code, + headers=proxy_response.headers, + background=background.BackgroundTask(proxy_response.aclose)) except (httpx.RequestError, httpx.HTTPStatusError) as e: - logger.error(f'Error when proxy request to {path}: ' + logger.error(f'Error when proxy request to {url}: ' f'{common_utils.format_exception(e)}') - return None + return e async def _proxy_with_retries( self, request: fastapi.Request) -> fastapi.responses.Response: @@ -152,20 +138,29 @@ async def _proxy_with_retries( request) if ready_replica_url is None: raise fastapi.HTTPException( + # 503 means that the server is currently + # unable to handle the incoming requests. status_code=503, detail='No ready replicas. ' 'Use "sky serve status [SERVICE_NAME]" ' 'to check the replica status.') - response = await self._proxy_request_to(ready_replica_url, request) - if response is not None: - return response + response_or_exception = await self._proxy_request_to( + ready_replica_url, request) + if not isinstance(response_or_exception, Exception): + return response_or_exception # TODO(tian): Fail fast for errors like 404 not found. if retry_cnt == constants.LB_MAX_RETRY: + exception = common_utils.format_exception( + response_or_exception, + use_bracket=True, + brighten_error_class=False) raise fastapi.HTTPException( + # 500 means internal server error. status_code=500, detail=f'Max retries {constants.LB_MAX_RETRY} exceeded. ' - 'Please use "sky serve logs [SERVICE_NAME] ' - '--load-balancer" for more information.') + f'Last error encountered: {exception}. Please use ' + '"sky serve logs [SERVICE_NAME] --load-balancer" ' + 'for more information.') current_backoff = backoff.current_backoff() logger.error(f'Retry in {current_backoff} seconds.') await asyncio.sleep(current_backoff) diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index 863561f7b6b..34c1fa4249b 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -22,6 +22,9 @@ def _request_repr(request: 'fastapi.Request') -> str: class LoadBalancingPolicy: """Abstract class for load balancing policies.""" + def __init__(self) -> None: + self.ready_replicas: List[str] = [] + def set_ready_replicas(self, ready_replicas: List[str]) -> None: raise NotImplementedError @@ -45,7 +48,7 @@ class RoundRobinPolicy(LoadBalancingPolicy): """Round-robin load balancing policy.""" def __init__(self) -> None: - self.ready_replicas: List[str] = [] + super().__init__() self.index = 0 def set_ready_replicas(self, ready_replicas: List[str]) -> None: diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 2abefc6fea0..914039a8315 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -456,7 +456,8 @@ def class_fullname(cls, skip_builtins: bool = True): def format_exception(e: Union[Exception, SystemExit, KeyboardInterrupt], - use_bracket: bool = False) -> str: + use_bracket: bool = False, + brighten_error_class: bool = True) -> str: """Format an exception to a string. Args: @@ -467,6 +468,8 @@ def format_exception(e: Union[Exception, SystemExit, KeyboardInterrupt], """ bright = colorama.Style.BRIGHT reset = colorama.Style.RESET_ALL + if not brighten_error_class: + bright, reset = '', '' if use_bracket: return f'{bright}[{class_fullname(e.__class__)}]{reset} {e}' return f'{bright}{class_fullname(e.__class__)}:{reset} {e}'