From ee3cabd57247ff0f25cb65c0ee46bd35ead8d11a Mon Sep 17 00:00:00 2001 From: Tian Xia Date: Sat, 21 Dec 2024 10:14:37 +0800 Subject: [PATCH] [Serve] Add and adopt least load policy as default poicy. (#4439) * [Serve] Add and adopt least load policy as default poicy. * Docs & smoke tests * error message for different lb policy * add minimal example * fix --- docs/source/serving/sky-serve.rst | 3 + examples/serve/minimal.yaml | 11 ++++ sky/serve/core.py | 14 ++++- sky/serve/load_balancer.py | 12 +++- sky/serve/load_balancing_policies.py | 70 +++++++++++++++++++++-- sky/serve/serve_state.py | 20 +++++-- sky/serve/serve_utils.py | 8 ++- sky/serve/service.py | 1 + sky/serve/service_spec.py | 6 +- tests/skyserve/load_balancer/service.yaml | 1 + tests/skyserve/update/new.yaml | 1 + tests/skyserve/update/old.yaml | 1 + 12 files changed, 131 insertions(+), 17 deletions(-) create mode 100644 examples/serve/minimal.yaml diff --git a/docs/source/serving/sky-serve.rst b/docs/source/serving/sky-serve.rst index 5a1a913b7ea..693102c0550 100644 --- a/docs/source/serving/sky-serve.rst +++ b/docs/source/serving/sky-serve.rst @@ -242,6 +242,9 @@ Under the hood, :code:`sky serve up`: #. Meanwhile, the controller provisions replica VMs which later run the services; #. Once any replica is ready, the requests sent to the Service Endpoint will be distributed to one of the endpoint replicas. +.. note:: + SkyServe uses least load load balancing to distribute the traffic to the replicas. It keeps track of the number of requests each replica has handled and routes the next request to the replica with the least load. + After the controller is provisioned, you'll see the following in :code:`sky serve status` output: .. image:: ../images/sky-serve-status-output-provisioning.png diff --git a/examples/serve/minimal.yaml b/examples/serve/minimal.yaml new file mode 100644 index 00000000000..c925d26f5d1 --- /dev/null +++ b/examples/serve/minimal.yaml @@ -0,0 +1,11 @@ +# An minimal example of a serve application. + +service: + readiness_probe: / + replicas: 1 + +resources: + ports: 8080 + cpus: 2+ + +run: python3 -m http.server 8080 diff --git a/sky/serve/core.py b/sky/serve/core.py index f6f6c53ad7b..561314bcbe0 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -384,6 +384,17 @@ def update( with ux_utils.print_exception_no_traceback(): raise RuntimeError(prompt) + original_lb_policy = service_record['load_balancing_policy'] + assert task.service is not None, 'Service section not found.' + if original_lb_policy != task.service.load_balancing_policy: + logger.warning( + f'{colorama.Fore.YELLOW}Current load balancing policy ' + f'{original_lb_policy!r} is different from the new policy ' + f'{task.service.load_balancing_policy!r}. Updating the load ' + 'balancing policy is not supported yet and it will be ignored. ' + 'The service will continue to use the current load balancing ' + f'policy.{colorama.Style.RESET_ALL}') + with rich_utils.safe_status( ux_utils.spinner_message('Initializing service')): controller_utils.maybe_translate_local_file_mounts_and_sync_up( @@ -581,9 +592,10 @@ def status( 'status': (sky.ServiceStatus) service status, 'controller_port': (Optional[int]) controller port, 'load_balancer_port': (Optional[int]) load balancer port, - 'policy': (Optional[str]) load balancer policy description, + 'policy': (Optional[str]) autoscaling policy description, 'requested_resources_str': (str) str representation of requested resources, + 'load_balancing_policy': (str) load balancing policy name, 'replica_info': (List[Dict[str, Any]]) replica information, } diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index 30697532a22..6b4621569d6 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -45,6 +45,8 @@ def __init__(self, # Use the registry to create the load balancing policy self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make( load_balancing_policy_name) + logger.info('Starting load balancer with policy ' + f'{load_balancing_policy_name}.') self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) # TODO(tian): httpx.Client has a resource limit of 100 max connections @@ -128,6 +130,7 @@ async def _proxy_request_to( encountered if anything goes wrong. """ logger.info(f'Proxy request to {url}') + self._load_balancing_policy.pre_execute_hook(url, request) try: # We defer the get of the client here on purpose, for case when the # replica is ready in `_proxy_with_retries` but refreshed before @@ -147,11 +150,16 @@ async def _proxy_request_to( content=await request.body(), timeout=constants.LB_STREAM_TIMEOUT) proxy_response = await client.send(proxy_request, stream=True) + + async def background_func(): + await proxy_response.aclose() + self._load_balancing_policy.post_execute_hook(url, request) + return fastapi.responses.StreamingResponse( content=proxy_response.aiter_raw(), status_code=proxy_response.status_code, headers=proxy_response.headers, - background=background.BackgroundTask(proxy_response.aclose)) + background=background.BackgroundTask(background_func)) except (httpx.RequestError, httpx.HTTPStatusError) as e: logger.error(f'Error when proxy request to {url}: ' f'{common_utils.format_exception(e)}') @@ -263,7 +271,7 @@ def run_load_balancer(controller_addr: str, parser.add_argument( '--load-balancing-policy', choices=available_policies, - default='round_robin', + default=lb_policies.DEFAULT_LB_POLICY, help=f'The load balancing policy to use. Available policies: ' f'{", ".join(available_policies)}.') args = parser.parse_args() diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index aec6eb01487..4ad69f78943 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -1,7 +1,9 @@ """LoadBalancingPolicy: Policy to select endpoint.""" +import collections import random +import threading import typing -from typing import List, Optional +from typing import Dict, List, Optional from sky import sky_logging @@ -13,6 +15,10 @@ # Define a registry for load balancing policies LB_POLICIES = {} DEFAULT_LB_POLICY = None +# Prior to #4439, the default policy was round_robin. We store the legacy +# default policy here to maintain backwards compatibility. Remove this after +# 2 minor release, i.e., 0.9.0. +LEGACY_DEFAULT_POLICY = 'round_robin' def _request_repr(request: 'fastapi.Request') -> str: @@ -38,11 +44,17 @@ def __init_subclass__(cls, name: str, default: bool = False): DEFAULT_LB_POLICY = name @classmethod - def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy': - """Create a load balancing policy from a name.""" + def make_policy_name(cls, policy_name: Optional[str]) -> str: + """Return the policy name.""" + assert DEFAULT_LB_POLICY is not None, 'No default policy set.' if policy_name is None: - policy_name = DEFAULT_LB_POLICY + return DEFAULT_LB_POLICY + return policy_name + @classmethod + def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy': + """Create a load balancing policy from a name.""" + policy_name = cls.make_policy_name(policy_name) if policy_name not in LB_POLICIES: raise ValueError(f'Unknown load balancing policy: {policy_name}') return LB_POLICIES[policy_name]() @@ -65,8 +77,16 @@ def select_replica(self, request: 'fastapi.Request') -> Optional[str]: def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: raise NotImplementedError + def pre_execute_hook(self, replica_url: str, + request: 'fastapi.Request') -> None: + pass + + def post_execute_hook(self, replica_url: str, + request: 'fastapi.Request') -> None: + pass + -class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True): +class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin'): """Round-robin load balancing policy.""" def __init__(self) -> None: @@ -90,3 +110,43 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: ready_replica_url = self.ready_replicas[self.index] self.index = (self.index + 1) % len(self.ready_replicas) return ready_replica_url + + +class LeastLoadPolicy(LoadBalancingPolicy, name='least_load', default=True): + """Least load load balancing policy.""" + + def __init__(self) -> None: + super().__init__() + self.load_map: Dict[str, int] = collections.defaultdict(int) + self.lock = threading.Lock() + + def set_ready_replicas(self, ready_replicas: List[str]) -> None: + if set(self.ready_replicas) == set(ready_replicas): + return + with self.lock: + self.ready_replicas = ready_replicas + for r in self.ready_replicas: + if r not in ready_replicas: + del self.load_map[r] + for replica in ready_replicas: + self.load_map[replica] = self.load_map.get(replica, 0) + + def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: + del request # Unused. + if not self.ready_replicas: + return None + with self.lock: + return min(self.ready_replicas, + key=lambda replica: self.load_map.get(replica, 0)) + + def pre_execute_hook(self, replica_url: str, + request: 'fastapi.Request') -> None: + del request # Unused. + with self.lock: + self.load_map[replica_url] += 1 + + def post_execute_hook(self, replica_url: str, + request: 'fastapi.Request') -> None: + del request # Unused. + with self.lock: + self.load_map[replica_url] -= 1 diff --git a/sky/serve/serve_state.py b/sky/serve/serve_state.py index 333e0138fb4..983e17d00ae 100644 --- a/sky/serve/serve_state.py +++ b/sky/serve/serve_state.py @@ -11,6 +11,7 @@ import colorama from sky.serve import constants +from sky.serve import load_balancing_policies as lb_policies from sky.utils import db_utils if typing.TYPE_CHECKING: @@ -76,6 +77,8 @@ def create_table(cursor: 'sqlite3.Cursor', conn: 'sqlite3.Connection') -> None: db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', 'active_versions', f'TEXT DEFAULT {json.dumps([])!r}') +db_utils.add_column_to_table(_DB.cursor, _DB.conn, 'services', + 'load_balancing_policy', 'TEXT DEFAULT NULL') _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG = 'UNIQUE constraint failed: services.name' @@ -241,7 +244,8 @@ def from_replica_statuses( def add_service(name: str, controller_job_id: int, policy: str, - requested_resources_str: str, status: ServiceStatus) -> bool: + requested_resources_str: str, load_balancing_policy: str, + status: ServiceStatus) -> bool: """Add a service in the database. Returns: @@ -254,10 +258,10 @@ def add_service(name: str, controller_job_id: int, policy: str, """\ INSERT INTO services (name, controller_job_id, status, policy, - requested_resources_str) - VALUES (?, ?, ?, ?, ?)""", + requested_resources_str, load_balancing_policy) + VALUES (?, ?, ?, ?, ?, ?)""", (name, controller_job_id, status.value, policy, - requested_resources_str)) + requested_resources_str, load_balancing_policy)) except sqlite3.IntegrityError as e: if str(e) != _UNIQUE_CONSTRAINT_FAILED_ERROR_MSG: @@ -324,7 +328,12 @@ def set_service_load_balancer_port(service_name: str, def _get_service_from_row(row) -> Dict[str, Any]: (current_version, name, controller_job_id, controller_port, load_balancer_port, status, uptime, policy, _, _, requested_resources_str, - _, active_versions) = row[:13] + _, active_versions, load_balancing_policy) = row[:14] + if load_balancing_policy is None: + # This entry in database was added in #4439, and it will always be set + # to a str value. If it is None, it means it is an legacy entry and is + # using the legacy default policy. + load_balancing_policy = lb_policies.LEGACY_DEFAULT_POLICY return { 'name': name, 'controller_job_id': controller_job_id, @@ -341,6 +350,7 @@ def _get_service_from_row(row) -> Dict[str, Any]: # integers in json format. This is mainly for display purpose. 'active_versions': json.loads(active_versions), 'requested_resources_str': requested_resources_str, + 'load_balancing_policy': load_balancing_policy, } diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 6ab932f278a..7e665929d66 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -811,7 +811,9 @@ def format_service_table(service_records: List[Dict[str, Any]], 'NAME', 'VERSION', 'UPTIME', 'STATUS', 'REPLICAS', 'ENDPOINT' ] if show_all: - service_columns.extend(['POLICY', 'REQUESTED_RESOURCES']) + service_columns.extend([ + 'AUTOSCALING_POLICY', 'LOAD_BALANCING_POLICY', 'REQUESTED_RESOURCES' + ]) service_table = log_utils.create_table(service_columns) replica_infos = [] @@ -832,6 +834,7 @@ def format_service_table(service_records: List[Dict[str, Any]], endpoint = get_endpoint(record) policy = record['policy'] requested_resources_str = record['requested_resources_str'] + load_balancing_policy = record['load_balancing_policy'] service_values = [ service_name, @@ -842,7 +845,8 @@ def format_service_table(service_records: List[Dict[str, Any]], endpoint, ] if show_all: - service_values.extend([policy, requested_resources_str]) + service_values.extend( + [policy, load_balancing_policy, requested_resources_str]) service_table.add_row(service_values) replica_table = _format_replica_table(replica_infos, show_all) diff --git a/sky/serve/service.py b/sky/serve/service.py index 0a1c7f34766..dbfc57b22bf 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -150,6 +150,7 @@ def _start(service_name: str, tmp_task_yaml: str, job_id: int): controller_job_id=job_id, policy=service_spec.autoscaling_policy_str(), requested_resources_str=backend_utils.get_task_resources_str(task), + load_balancing_policy=service_spec.load_balancing_policy, status=serve_state.ServiceStatus.CONTROLLER_INIT) # Directly throw an error here. See sky/serve/api.py::up # for more details. diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 000eed139f1..fbbca5bc0dd 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -8,6 +8,7 @@ from sky import serve from sky.serve import constants +from sky.serve import load_balancing_policies as lb_policies from sky.utils import common_utils from sky.utils import schemas from sky.utils import ux_utils @@ -327,5 +328,6 @@ def use_ondemand_fallback(self) -> bool: return self._use_ondemand_fallback @property - def load_balancing_policy(self) -> Optional[str]: - return self._load_balancing_policy + def load_balancing_policy(self) -> str: + return lb_policies.LoadBalancingPolicy.make_policy_name( + self._load_balancing_policy) diff --git a/tests/skyserve/load_balancer/service.yaml b/tests/skyserve/load_balancer/service.yaml index 742b8efd2f4..232136d4a61 100644 --- a/tests/skyserve/load_balancer/service.yaml +++ b/tests/skyserve/load_balancer/service.yaml @@ -5,6 +5,7 @@ service: initial_delay_seconds: 180 replica_policy: min_replicas: 3 + load_balancing_policy: round_robin resources: ports: 8080 diff --git a/tests/skyserve/update/new.yaml b/tests/skyserve/update/new.yaml index 2c9cebd0cb5..5e5d853e09d 100644 --- a/tests/skyserve/update/new.yaml +++ b/tests/skyserve/update/new.yaml @@ -3,6 +3,7 @@ service: path: /health initial_delay_seconds: 100 replicas: 2 + load_balancing_policy: round_robin resources: ports: 8081 diff --git a/tests/skyserve/update/old.yaml b/tests/skyserve/update/old.yaml index 4b99cb92e8c..4cb19b8327b 100644 --- a/tests/skyserve/update/old.yaml +++ b/tests/skyserve/update/old.yaml @@ -3,6 +3,7 @@ service: path: /health initial_delay_seconds: 20 replicas: 2 + load_balancing_policy: round_robin resources: ports: 8080