Skip to content

Commit

Permalink
[SkyServe] Rolling Update (#2935)
Browse files Browse the repository at this point in the history
* rewrite skyserve update

* remove sigint

* delete storage buckets of older versions

* revert cleanup storage

* clean up storage from past versions

* add a todo comment

* add service field backward compatibility

* add service field backward compatibility

* remove stale comments

* update comment

* pr review

* add versions to db

* refactor

* fix weird serve-update-new

* del

* merge master

* pr review

* clip target_num_replicas

* not expose mixed_replica_versions to user

* address pr

* fix bug

* fix bug

* fix

* bug fix

* update comments

* if statements

* move add_column_to_table place

* move graceful termination to _terminate_replica

* move version checking code to _refresh_process_pool

* code reviews

* address code reviews

* version updating outside the for loop

* code review

* clean up storage

* rename logger info

* code change

* update service atomically

* fix bug

* fix

* update variable names

* atomic update to increment version

* add skyserve update smoke test
  • Loading branch information
MaoZiming authored Jan 24, 2024
1 parent e1f5523 commit 3765f03
Show file tree
Hide file tree
Showing 14 changed files with 728 additions and 103 deletions.
136 changes: 95 additions & 41 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4341,6 +4341,52 @@ def serve():
pass


def _generate_task_with_service(service_yaml_args: List[str],
                                not_supported_cmd: str) -> sky.Task:
    """Build and validate a service task from a service YAML entrypoint.

    Args:
        service_yaml_args: Pieces of the SERVICE_YAML argument as collected
            by click (the argument is declared with nargs=-1).
        not_supported_cmd: Command name substituted into the usage error
            raised when the entrypoint turns out to be a DAG.

    Returns:
        The task built from the entrypoint. Its `service` section is set and
        every resource requests the same single numeric port.

    Raises:
        click.UsageError: If the entrypoint is not recognized as YAML, or if
            it produces a `sky.Dag` instead of a single task.
        ValueError: If the `service` section is missing, a resource does not
            specify exactly one port, the port is not all digits, or two
            resources specify different ports.
    """
    yaml_ok, _ = _check_yaml(''.join(service_yaml_args))
    if not yaml_ok:
        raise click.UsageError('SERVICE_YAML must be a valid YAML file.')

    # The service_yaml argument keeps nargs=-1 so that the generic entrypoint
    # helper below can be reused as-is.
    task = _make_task_or_dag_from_entrypoint_with_overrides(
        service_yaml_args, entrypoint_name='Service')
    if isinstance(task, sky.Dag):
        dag_error = _DAG_NOT_SUPPORTED_MESSAGE.format(
            command=not_supported_cmd)
        raise click.UsageError(dag_error)

    if task.service is None:
        with ux_utils.print_exception_no_traceback():
            raise ValueError('Service section not found in the YAML file. '
                             'To fix, add a valid `service` field.')

    # Port shared by all resources; fixed by the first resource inspected.
    service_port: Optional[int] = None
    for candidate in list(task.resources):
        candidate_ports = candidate.ports
        if candidate_ports is None or len(candidate_ports) != 1:
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    'Must only specify one port in resources. Each replica '
                    'will use the port specified as application ingress port.')

        service_port_str = candidate_ports[0]
        if not service_port_str.isdigit():
            # Rejects non-numeric specs such as a port range (10000-10010).
            with ux_utils.print_exception_no_traceback():
                raise ValueError(f'Port {service_port_str!r} is not a valid '
                                 'port number. Please specify a single port '
                                 f'instead. Got: {service_port_str!r}')

        # For now every replica must use the same port; allowing different
        # ports per replica should be fine in the future.
        resource_port = int(service_port_str)
        if service_port is None:
            service_port = resource_port
        elif service_port != resource_port:
            with ux_utils.print_exception_no_traceback():
                raise ValueError(f'Got multiple ports: {service_port} and '
                                 f'{resource_port} in different resources. '
                                 'Please specify single port instead.')
    return task


@serve.command('up', cls=_DocumentedCodeCommand)
@click.argument('service_yaml',
required=True,
Expand Down Expand Up @@ -4396,47 +4442,8 @@ def serve_up(
if service_name is None:
service_name = serve_lib.generate_service_name()

is_yaml, _ = _check_yaml(''.join(service_yaml))
if not is_yaml:
raise click.UsageError('SERVICE_YAML must be a valid YAML file.')
# We keep nargs=-1 in service_yaml argument to reuse this function.
task = _make_task_or_dag_from_entrypoint_with_overrides(
service_yaml, entrypoint_name='Service')
if isinstance(task, sky.Dag):
raise click.UsageError(
_DAG_NOT_SUPPORTED_MESSAGE.format(command='sky serve up'))

if task.service is None:
with ux_utils.print_exception_no_traceback():
raise ValueError('Service section not found in the YAML file. '
'To fix, add a valid `service` field.')
service_port: Optional[int] = None
for requested_resources in list(task.resources):
if requested_resources.ports is None or len(
requested_resources.ports) != 1:
with ux_utils.print_exception_no_traceback():
raise ValueError(
'Must only specify one port in resources. Each replica '
'will use the port specified as application ingress port.')
service_port_str = requested_resources.ports[0]
if not service_port_str.isdigit():
# For the case when the user specified a port range like 10000-10010
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Port {service_port_str!r} is not a valid '
'port number. Please specify a single port '
f'instead. Got: {service_port_str!r}')
# We request all the replicas using the same port for now, but it
# should be fine to allow different replicas to use different ports
# in the future.
resource_port = int(service_port_str)
if service_port is None:
service_port = resource_port
if service_port != resource_port:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Got multiple ports: {service_port} and '
f'{resource_port} in different resources. '
'Please specify single port instead.')

task = _generate_task_with_service(service_yaml_args=service_yaml,
not_supported_cmd='sky serve up')
click.secho('Service Spec:', fg='cyan')
click.echo(task.service)

Expand All @@ -4454,6 +4461,53 @@ def serve_up(
serve_lib.up(task, service_name)


# TODO(MaoZiming): Update the docs.
# TODO(MaoZiming): Expose the mixed-replica traffic option to the user.
# Currently, we do not mix traffic between old and new replicas.
@serve.command('update', cls=_DocumentedCodeCommand)
@click.argument('service_name', required=True, type=str)
@click.argument('service_yaml',
                required=True,
                type=str,
                nargs=-1,
                **_get_shell_complete_args(_complete_file_name))
@click.option('--yes',
              '-y',
              is_flag=True,
              default=False,
              required=False,
              help='Skip confirmation prompt.')
def serve_update(service_name: str, service_yaml: List[str], yes: bool):
    """Update a SkyServe service.

    service_yaml must point to a valid YAML file.

    Example:

    .. code-block:: bash

        sky serve update sky-service-16aa new_service.yaml

    """
    # Parse and validate the YAML with the same helper `sky serve up` uses;
    # it raises on a DAG entrypoint, a missing `service` section, or an
    # invalid port specification.
    task = _generate_task_with_service(service_yaml_args=service_yaml,
                                       not_supported_cmd='sky serve update')
    click.secho('Service Spec:', fg='cyan')
    click.echo(task.service)

    click.secho('New replica will use the following resources (estimated):',
                fg='cyan')
    # Wrap the single task in a DAG so it can be passed to the optimizer.
    # NOTE(review): presumably sky.optimize prints the estimated resource
    # plan announced by the line above -- confirm.
    with sky.Dag() as dag:
        dag.add(task)
    sky.optimize(dag)

    if not yes:
        # abort=True makes click abort the command if the user declines.
        click.confirm(f'Updating service {service_name!r}. Proceed?',
                      default=True,
                      abort=True,
                      show_default=True)

    serve_lib.update(task, service_name)


@serve.command('status', cls=_DocumentedCodeCommand)
@click.option('--all',
'-a',
Expand Down
2 changes: 2 additions & 0 deletions sky/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import os

from sky.serve.constants import ENDPOINT_PROBE_INTERVAL_SECONDS
from sky.serve.constants import INITIAL_VERSION
from sky.serve.constants import LB_CONTROLLER_SYNC_INTERVAL_SECONDS
from sky.serve.constants import SERVICES_TASK_CPU_DEMAND
from sky.serve.constants import SKYSERVE_METADATA_DIR
from sky.serve.core import down
from sky.serve.core import status
from sky.serve.core import tail_logs
from sky.serve.core import up
from sky.serve.core import update
from sky.serve.serve_state import ReplicaStatus
from sky.serve.serve_state import ServiceStatus
from sky.serve.serve_utils import format_service_table
Expand Down
96 changes: 77 additions & 19 deletions sky/serve/autoscalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,28 @@ def __init__(self, spec: 'service_spec.SkyServiceSpec') -> None:
max_replicas: Maximum number of replicas. Default to fixed
number of replicas, i.e. min_replicas == max_replicas.
target_num_replicas: Target number of replicas output by autoscaler.
latest_version: latest version of the service.
"""
self.min_replicas: int = spec.min_replicas
self.max_replicas: int = spec.max_replicas or spec.min_replicas
self.max_replicas: int = (spec.max_replicas if spec.max_replicas
is not None else spec.min_replicas)
# Target number of replicas is initialized to min replicas.
self.target_num_replicas: int = spec.min_replicas
self.latest_version: int = constants.INITIAL_VERSION

def update_version(self, version: int,
                   spec: 'service_spec.SkyServiceSpec') -> None:
    """Switch the autoscaler to a newer service version and its spec.

    Args:
        version: Version being rolled out. It must be strictly greater than
            `self.latest_version`; otherwise an error is logged and the
            call leaves the autoscaler unchanged.
        spec: Service spec of the new version. Its `min_replicas` and
            `max_replicas` become the new replica bounds; when
            `max_replicas` is None the upper bound equals `min_replicas`.
    """
    if version <= self.latest_version:
        logger.error(f'Invalid version: {version}, '
                     f'latest version: {self.latest_version}')
        return
    self.latest_version = version
    # Write the bounds to min_replicas/max_replicas: these are the attributes
    # __init__ defines and the clipping below reads. Storing them under other
    # names would leave the old bounds in effect after an update.
    self.min_replicas = spec.min_replicas
    self.max_replicas = (spec.max_replicas if spec.max_replicas is not None
                         else spec.min_replicas)
    # Re-clip self.target_num_replicas to the new min and max replicas.
    self.target_num_replicas = max(
        self.min_replicas, min(self.max_replicas, self.target_num_replicas))

def collect_request_information(
self, request_aggregator_info: Dict[str, Any]) -> None:
Expand All @@ -89,8 +107,7 @@ class RequestRateAutoscaler(Autoscaler):
the threshold.
"""

def __init__(self, spec: 'service_spec.SkyServiceSpec',
qps_window_size: int) -> None:
def __init__(self, spec: 'service_spec.SkyServiceSpec') -> None:
"""Initialize the request rate autoscaler.
Variables:
Expand All @@ -101,11 +118,12 @@ def __init__(self, spec: 'service_spec.SkyServiceSpec',
downscale_counter: counter for downscale number of replicas.
scale_up_consecutive_periods: period for scaling up.
scale_down_consecutive_periods: period for scaling down.
bootstrap_done: whether bootstrap is done.
"""
super().__init__(spec)
self.target_qps_per_replica: Optional[
float] = spec.target_qps_per_replica
self.qps_window_size: int = qps_window_size
self.qps_window_size: int = constants.AUTOSCALER_QPS_WINDOW_SIZE_SECONDS
self.request_timestamps: List[float] = []
self.upscale_counter: int = 0
self.downscale_counter: int = 0
Expand All @@ -122,11 +140,27 @@ def __init__(self, spec: 'service_spec.SkyServiceSpec',
self.scale_down_consecutive_periods: int = int(
downscale_delay_seconds /
constants.AUTOSCALER_DEFAULT_DECISION_INTERVAL_SECONDS)
# Target number of replicas is initialized to min replicas.
# TODO(MaoZiming): add init replica numbers in SkyServe spec.
self.target_num_replicas: int = spec.min_replicas

self.bootstrap_done: bool = False

def update_version(self, version: int,
                   spec: 'service_spec.SkyServiceSpec') -> None:
    """Apply a newer service version's spec to the request-rate autoscaler.

    Extends the base-class update with the request-rate settings: the target
    QPS per replica and the number of consecutive decision periods required
    before scaling up or down.

    Args:
        version: Version being rolled out. A version that is not strictly
            greater than `self.latest_version` is rejected: the base class
            logs the error and none of the request-rate settings change.
        spec: Service spec of the new version. Unset upscale/downscale
            delays fall back to the defaults in `constants`.
    """
    # Evaluate staleness before delegating, because the base class advances
    # self.latest_version when it accepts the update.
    is_stale = version <= self.latest_version
    super().update_version(version, spec)
    if is_stale:
        # The base class rejected this version; do not apply the rejected
        # spec's settings on top of the current ones.
        return

    self.target_qps_per_replica = spec.target_qps_per_replica
    upscale_delay_seconds = (
        spec.upscale_delay_seconds if spec.upscale_delay_seconds is not None
        else constants.AUTOSCALER_DEFAULT_UPSCALE_DELAY_SECONDS)
    # Convert the delay into a count of consecutive decision intervals.
    self.scale_up_consecutive_periods = int(
        upscale_delay_seconds /
        constants.AUTOSCALER_DEFAULT_DECISION_INTERVAL_SECONDS)
    downscale_delay_seconds = (
        spec.downscale_delay_seconds
        if spec.downscale_delay_seconds is not None else
        constants.AUTOSCALER_DEFAULT_DOWNSCALE_DELAY_SECONDS)
    self.scale_down_consecutive_periods = int(
        downscale_delay_seconds /
        constants.AUTOSCALER_DEFAULT_DECISION_INTERVAL_SECONDS)

def collect_request_information(
self, request_aggregator_info: Dict[str, Any]) -> None:
"""Collect request information from aggregator for autoscaling.
Expand Down Expand Up @@ -201,10 +235,18 @@ def evaluate_scaling(
override dict. Active migration could require returning both SCALE_UP
and SCALE_DOWN.
"""
launched_replica_infos = [
info for info in replica_infos if info.is_launched
]
num_launched_replicas = len(launched_replica_infos)
provisioning_and_launched_new_replica: List[
'replica_managers.ReplicaInfo'] = []
ready_new_replica: List['replica_managers.ReplicaInfo'] = []
old_replicas: List['replica_managers.ReplicaInfo'] = []
for info in replica_infos:
if info.version == self.latest_version:
if info.is_launched:
provisioning_and_launched_new_replica.append(info)
if info.is_ready:
ready_new_replica.append(info)
else:
old_replicas.append(info)

self.target_num_replicas = self._get_desired_num_replicas()
logger.info(
Expand All @@ -213,7 +255,8 @@ def evaluate_scaling(
f'{self.scale_up_consecutive_periods}, '
f'Downscale counter: {self.downscale_counter}/'
f'{self.scale_down_consecutive_periods} '
f'Number of launched replicas: {num_launched_replicas}')
'Number of launched latest replicas: '
f'{len(provisioning_and_launched_new_replica)}')

scaling_options = []
all_replica_ids_to_scale_down: List[int] = []
Expand All @@ -222,25 +265,40 @@ def _get_replica_ids_to_scale_down(num_limit: int) -> List[int]:

status_order = serve_state.ReplicaStatus.scale_down_decision_order()
launched_replica_infos_sorted = sorted(
launched_replica_infos,
provisioning_and_launched_new_replica,
key=lambda info: status_order.index(info.status)
if info.status in status_order else len(status_order))

return [info.replica_id for info in launched_replica_infos_sorted
][:num_limit]

if num_launched_replicas < self.target_num_replicas:
num_replicas_to_scale_up = (self.target_num_replicas -
num_launched_replicas)
# Case 1: Once at least min_replicas new replicas are ready,
# all traffic is directed to them, so all old replicas
# can be scaled down.
if len(ready_new_replica) >= self.min_replicas:
for info in old_replicas:
all_replica_ids_to_scale_down.append(info.replica_id)

# Case 2. when provisioning_and_launched_new_replica is less
# than target_num_replicas, we always scale up new replicas.
if len(provisioning_and_launched_new_replica
) < self.target_num_replicas:
num_replicas_to_scale_up = (
self.target_num_replicas -
len(provisioning_and_launched_new_replica))

for _ in range(num_replicas_to_scale_up):
scaling_options.append(
AutoscalerDecision(AutoscalerDecisionOperator.SCALE_UP,
target=None))

elif num_launched_replicas > self.target_num_replicas:
num_replicas_to_scale_down = (num_launched_replicas -
self.target_num_replicas)
# Case 3: when provisioning_and_launched_new_replica is more
# than target_num_replicas, we scale down new replicas.
if len(provisioning_and_launched_new_replica
) > self.target_num_replicas:
num_replicas_to_scale_down = (
len(provisioning_and_launched_new_replica) -
self.target_num_replicas)
all_replica_ids_to_scale_down.extend(
_get_replica_ids_to_scale_down(
num_limit=num_replicas_to_scale_down))
Expand Down
3 changes: 3 additions & 0 deletions sky/serve/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@
CONTROLLER_PORT_START = 20001
LOAD_BALANCER_PORT_START = 30001
LOAD_BALANCER_PORT_RANGE = '30001-30100'

# Initial version of service.
INITIAL_VERSION = 1
Loading

0 comments on commit 3765f03

Please sign in to comment.