Skip to content

Commit

Permalink
[Serve] Support manually terminating a replica and with purge option (#…
Browse files Browse the repository at this point in the history
…4032)

* define replica id param in cli

* create endpoint on controller

* call controller endpoint to scale down replica

* add classmethod decorator

* add handler methods for readability in cli

* update docstr and error msg, and inline in cli

* update log and return err msg

* add docstr, catch and reraise err, add stopped and nonexistent message

* inline constant to avoid circular import

* fix error statement and return encoded str

* add purge feature

* add purge replica usage in docstr

* use .get to handle unexpected packages

* fix: diff terminate replica when failed/purging or not

* fix: stay up to date for `is_controller_accessible`

* revert

* up to date with current APIs

* error handling

* when purged remove record in the main loop

* refactor due to reviewer's suggestions

* combine functions

* fix: terminate the healthy replica even with purge option

* remove abbr

* Update sky/serve/core.py

Co-authored-by: Tian Xia <[email protected]>

* Update sky/serve/core.py

Co-authored-by: Tian Xia <[email protected]>

* Update sky/serve/controller.py

Co-authored-by: Tian Xia <[email protected]>

* Update sky/serve/controller.py

Co-authored-by: Tian Xia <[email protected]>

* Update sky/cli.py

Co-authored-by: Tian Xia <[email protected]>

* got services hint

* check if not yes in the outside if branch

* fix some output messages

* Update sky/serve/core.py

Co-authored-by: Tian Xia <[email protected]>

* set conflict status code for already scheduled termination

* combine purge and normal terminating down branch together

* bump version

* global exception handler to render a json response with error messages

* fix: use responses.JSONResponse for dict serialize

* error messages for old controller

* fix: check version mismatch in generated code

* revert mistakenly change update_service

* refine already in terminating message

* fix: branch code workaround in cls.build

* wording

Co-authored-by: Tian Xia <[email protected]>

* refactor due to reviewer's comments

* fix use ux_utils

Co-authored-by: Tian Xia <[email protected]>

* add changelog as comments

* fix messages

* edit the message for mismatch error

Co-authored-by: Tian Xia <[email protected]>

* no traceback when raising in `terminate_replica`

* messages decode

* Apply suggestions from code review

Co-authored-by: Tian Xia <[email protected]>

* format

* forma

* Empty commit

---------

Co-authored-by: David Tran <[email protected]>
Co-authored-by: David Tran <[email protected]>
Co-authored-by: Tian Xia <[email protected]>
  • Loading branch information
4 people authored Oct 19, 2024
1 parent 9201def commit c6ae536
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 18 deletions.
58 changes: 46 additions & 12 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4380,9 +4380,14 @@ def serve_status(all: bool, endpoint: bool, service_names: List[str]):
default=False,
required=False,
help='Skip confirmation prompt.')
@click.option('--replica-id',
default=None,
type=int,
help='Tear down a given replica')
# pylint: disable=redefined-builtin
def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
"""Teardown service(s).
def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool,
replica_id: Optional[int]):
"""Teardown service(s) or a replica.
SERVICE_NAMES is the name of the service (or glob pattern) to tear down. If
both SERVICE_NAMES and ``--all`` are supplied, the latter takes precedence.
Expand All @@ -4408,6 +4413,12 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
\b
# Forcefully tear down a service in failed status.
sky serve down failed-service --purge
\b
# Tear down a specific replica
sky serve down my-service --replica-id 1
\b
# Forcefully tear down a specific replica, even in failed status.
sky serve down my-service --replica-id 1 --purge
"""
if sum([len(service_names) > 0, all]) != 1:
argument_str = f'SERVICE_NAMES={",".join(service_names)}' if len(
Expand All @@ -4417,22 +4428,45 @@ def serve_down(service_names: List[str], all: bool, purge: bool, yes: bool):
'Can only specify one of SERVICE_NAMES or --all. '
f'Provided {argument_str!r}.')

replica_id_is_defined = replica_id is not None
if replica_id_is_defined:
if len(service_names) != 1:
service_names_str = ', '.join(service_names)
raise click.UsageError(f'The --replica-id option can only be used '
f'with a single service name. Got: '
f'{service_names_str}.')
if all:
raise click.UsageError('The --replica-id option cannot be used '
'with the --all option.')

backend_utils.is_controller_accessible(
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message='All services should have been terminated.',
exit_if_not_accessible=True)

if not yes:
quoted_service_names = [f'{name!r}' for name in service_names]
service_identity_str = f'service(s) {", ".join(quoted_service_names)}'
if all:
service_identity_str = 'all services'
click.confirm(f'Terminating {service_identity_str}. Proceed?',
default=True,
abort=True,
show_default=True)

serve_lib.down(service_names=service_names, all=all, purge=purge)
if replica_id_is_defined:
click.confirm(
f'Terminating replica ID {replica_id} in '
f'{service_names[0]!r}. Proceed?',
default=True,
abort=True,
show_default=True)
else:
quoted_service_names = [f'{name!r}' for name in service_names]
service_identity_str = (f'service(s) '
f'{", ".join(quoted_service_names)}')
if all:
service_identity_str = 'all services'
click.confirm(f'Terminating {service_identity_str}. Proceed?',
default=True,
abort=True,
show_default=True)

if replica_id_is_defined:
serve_lib.terminate_replica(service_names[0], replica_id, purge)
else:
serve_lib.down(service_names=service_names, all=all, purge=purge)


@serve.command('logs', cls=_DocumentedCodeCommand)
Expand Down
2 changes: 2 additions & 0 deletions sky/serve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sky.serve.core import down
from sky.serve.core import status
from sky.serve.core import tail_logs
from sky.serve.core import terminate_replica
from sky.serve.core import up
from sky.serve.core import update
from sky.serve.serve_state import ReplicaStatus
Expand Down Expand Up @@ -42,6 +43,7 @@
'SKY_SERVE_CONTROLLER_NAME',
'SKYSERVE_METADATA_DIR',
'status',
'terminate_replica',
'tail_logs',
'up',
'update',
Expand Down
9 changes: 8 additions & 1 deletion sky/serve/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,11 @@
# change for the serve_utils.ServeCodeGen, we need to bump this version, so that
# the user can be notified to update their SkyPilot serve version on the remote
# cluster.
SERVE_VERSION = 1
# Changelog:
# v1.0 - Introduce rolling update.
# v2.0 - Added template-replica feature.
SERVE_VERSION = 2

TERMINATE_REPLICA_VERSION_MISMATCH_ERROR = (
'The version of service is outdated and does not support manually '
'terminating replicas. Please terminate the service and spin up again.')
70 changes: 70 additions & 0 deletions sky/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import traceback
from typing import Any, Dict, List

import colorama
import fastapi
from fastapi import responses
import uvicorn
Expand Down Expand Up @@ -157,6 +158,75 @@ async def update_service(request: fastapi.Request) -> fastapi.Response:
return responses.JSONResponse(content={'message': 'Error'},
status_code=500)

@self._app.post('/controller/terminate_replica')
async def terminate_replica(
request: fastapi.Request) -> fastapi.Response:
request_data = await request.json()
replica_id = request_data['replica_id']
assert isinstance(replica_id,
int), 'Error: replica ID must be an integer.'
purge = request_data['purge']
assert isinstance(purge, bool), 'Error: purge must be a boolean.'
replica_info = serve_state.get_replica_info_from_id(
self._service_name, replica_id)
assert replica_info is not None, (f'Error: replica '
f'{replica_id} does not exist.')
replica_status = replica_info.status

if replica_status == serve_state.ReplicaStatus.SHUTTING_DOWN:
return responses.JSONResponse(
status_code=409,
content={
'message':
f'Replica {replica_id} of service '
f'{self._service_name!r} is already in the process '
f'of terminating. Skip terminating now.'
})

if (replica_status in serve_state.ReplicaStatus.failed_statuses()
and not purge):
return responses.JSONResponse(
status_code=409,
content={
'message': f'{colorama.Fore.YELLOW}Replica '
f'{replica_id} of service '
f'{self._service_name!r} is in failed '
f'status ({replica_info.status}). '
f'Skipping its termination as it could '
f'lead to a resource leak. '
f'(Use `sky serve down '
f'{self._service_name!r} --replica-id '
f'{replica_id} --purge` to '
'forcefully terminate the replica.)'
f'{colorama.Style.RESET_ALL}'
})

self._replica_manager.scale_down(replica_id, purge=purge)

action = 'terminated' if not purge else 'purged'
message = (f'{colorama.Fore.GREEN}Replica {replica_id} of service '
f'{self._service_name!r} is scheduled to be '
f'{action}.{colorama.Style.RESET_ALL}\n'
f'Please use {ux_utils.BOLD}sky serve status '
f'{self._service_name}{ux_utils.RESET_BOLD} '
f'to check the latest status.')
return responses.JSONResponse(status_code=200,
content={'message': message})

@self._app.exception_handler(Exception)
async def validation_exception_handler(
request: fastapi.Request, exc: Exception) -> fastapi.Response:
with ux_utils.enable_traceback():
logger.error(f'Error in controller: {exc!r}')
return responses.JSONResponse(
status_code=500,
content={
'message':
(f'Failed method {request.method} at URL {request.url}.'
f' Exception message is {exc!r}.')
},
)

threading.Thread(target=self._run_autoscaler).start()

logger.info('SkyServe Controller started on '
Expand Down
47 changes: 47 additions & 0 deletions sky/serve/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,53 @@ def down(
sky_logging.print(stdout)


@usage_lib.entrypoint
def terminate_replica(service_name: str, replica_id: int, purge: bool) -> None:
"""Tear down a specific replica for the given service.
Args:
service_name: Name of the service.
replica_id: ID of replica to terminate.
purge: Whether to terminate replicas in a failed status. These replicas
may lead to resource leaks, so we require the user to explicitly
specify this flag to make sure they are aware of this potential
resource leak.
Raises:
sky.exceptions.ClusterNotUpError: if the sky sere controller is not up.
RuntimeError: if failed to terminate the replica.
"""
handle = backend_utils.is_controller_accessible(
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
stopped_message=
'No service is running now. Please spin up a service first.',
non_existent_message='No service is running now. '
'Please spin up a service first.',
)

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend)

code = serve_utils.ServeCodeGen.terminate_replica(service_name, replica_id,
purge)
returncode, stdout, stderr = backend.run_on_head(handle,
code,
require_outputs=True,
stream_logs=False,
separate_stderr=True)

try:
subprocess_utils.handle_returncode(returncode,
code,
'Failed to terminate the replica',
stderr,
stream_logs=True)
except exceptions.CommandError as e:
raise RuntimeError(e.error_msg) from e

sky_logging.print(stdout)


@usage_lib.entrypoint
def status(
service_names: Optional[Union[str,
Expand Down
17 changes: 13 additions & 4 deletions sky/serve/replica_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ class ReplicaStatusProperty:
is_scale_down: bool = False
# The replica's spot instance was preempted.
preempted: bool = False
# Whether the replica is purged.
purged: bool = False

def remove_terminated_replica(self) -> bool:
"""Whether to remove the replica record from the replica table.
Expand Down Expand Up @@ -307,6 +309,8 @@ def should_track_service_status(self) -> bool:
return False
if self.preempted:
return False
if self.purged:
return False
return True

def to_replica_status(self) -> serve_state.ReplicaStatus:
Expand Down Expand Up @@ -590,7 +594,7 @@ def scale_up(self,
"""
raise NotImplementedError

def scale_down(self, replica_id: int) -> None:
def scale_down(self, replica_id: int, purge: bool = False) -> None:
"""Scale down replica with replica_id."""
raise NotImplementedError

Expand Down Expand Up @@ -679,7 +683,8 @@ def _terminate_replica(self,
replica_id: int,
sync_down_logs: bool,
replica_drain_delay_seconds: int,
is_scale_down: bool = False) -> None:
is_scale_down: bool = False,
purge: bool = False) -> None:

if replica_id in self._launch_process_pool:
info = serve_state.get_replica_info_from_id(self._service_name,
Expand Down Expand Up @@ -763,16 +768,18 @@ def _download_and_stream_logs(info: ReplicaInfo):
)
info.status_property.sky_down_status = ProcessStatus.RUNNING
info.status_property.is_scale_down = is_scale_down
info.status_property.purged = purge
serve_state.add_or_update_replica(self._service_name, replica_id, info)
p.start()
self._down_process_pool[replica_id] = p

def scale_down(self, replica_id: int) -> None:
def scale_down(self, replica_id: int, purge: bool = False) -> None:
self._terminate_replica(
replica_id,
sync_down_logs=False,
replica_drain_delay_seconds=_DEFAULT_DRAIN_SECONDS,
is_scale_down=True)
is_scale_down=True,
purge=purge)

def _handle_preemption(self, info: ReplicaInfo) -> bool:
"""Handle preemption of the replica if any error happened.
Expand Down Expand Up @@ -911,6 +918,8 @@ def _refresh_process_pool(self) -> None:
# since user should fixed the error before update.
elif info.version != self.latest_version:
removal_reason = 'for version outdated'
elif info.status_property.purged:
removal_reason = 'for purge'
else:
logger.info(f'Termination of replica {replica_id} '
'finished. Replica info is kept since some '
Expand Down
44 changes: 43 additions & 1 deletion sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,36 @@ def update_service_encoded(service_name: str, version: int, mode: str) -> str:
return common_utils.encode_payload(service_msg)


def terminate_replica(service_name: str, replica_id: int, purge: bool) -> str:
service_status = _get_service_status(service_name)
if service_status is None:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Service {service_name!r} does not exist.')
replica_info = serve_state.get_replica_info_from_id(service_name,
replica_id)
if replica_info is None:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Replica {replica_id} for service {service_name} does not '
'exist.')

controller_port = service_status['controller_port']
resp = requests.post(
_CONTROLLER_URL.format(CONTROLLER_PORT=controller_port) +
'/controller/terminate_replica',
json={
'replica_id': replica_id,
'purge': purge,
})

message: str = resp.json()['message']
if resp.status_code != 200:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Failed to terminate replica {replica_id} '
f'in {service_name}. Reason:\n{message}')
return message


def _get_service_status(
service_name: str,
with_replica_info: bool = True) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -735,7 +765,7 @@ def _get_replicas(service_record: Dict[str, Any]) -> str:


def get_endpoint(service_record: Dict[str, Any]) -> str:
# Don't use backend_utils.is_controller_up since it is too slow.
# Don't use backend_utils.is_controller_accessible since it is too slow.
handle = global_user_state.get_handle_from_cluster_name(
SKY_SERVE_CONTROLLER_NAME)
assert isinstance(handle, backends.CloudVmRayResourceHandle)
Expand Down Expand Up @@ -915,6 +945,18 @@ def terminate_services(cls, service_names: Optional[List[str]],
]
return cls._build(code)

@classmethod
def terminate_replica(cls, service_name: str, replica_id: int,
purge: bool) -> str:
code = [
f'(lambda: print(serve_utils.terminate_replica({service_name!r}, '
f'{replica_id}, {purge}), end="", flush=True) '
'if getattr(constants, "SERVE_VERSION", 0) >= 2 else '
f'exec("raise RuntimeError('
f'{constants.TERMINATE_REPLICA_VERSION_MISMATCH_ERROR!r})"))()'
]
return cls._build(code)

@classmethod
def wait_service_registration(cls, service_name: str, job_id: int) -> str:
code = [
Expand Down

0 comments on commit c6ae536

Please sign in to comment.