Skip to content

Commit

Permalink
[SkyServe] Serving with Spot (#2749)
Browse files Browse the repository at this point in the history
* rebase and fix bugs

* fix PR reviews

* fix

* fix comments

* rename tests

* fix yaml replica_num
  • Loading branch information
MaoZiming authored Nov 2, 2023
1 parent 8773d0c commit 2008590
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 9 deletions.
4 changes: 2 additions & 2 deletions sky/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def __init__(self, service_name: str, service_spec: serve.SkyServiceSpec,
task_yaml: str, port: int) -> None:
self.service_name = service_name
self.replica_manager: replica_managers.ReplicaManager = (
replica_managers.SkyPilotReplicaManager(service_name,
service_spec,
replica_managers.SkyPilotReplicaManager(service_name=service_name,
spec=service_spec,
task_yaml_path=task_yaml))
self.autoscaler: autoscalers.Autoscaler = (
autoscalers.RequestRateAutoscaler(
Expand Down
63 changes: 60 additions & 3 deletions sky/serve/replica_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sky import exceptions
from sky import global_user_state
from sky import sky_logging
from sky import status_lib
from sky.backends import backend_utils
from sky.serve import constants as serve_constants
from sky.serve import serve_state
Expand Down Expand Up @@ -169,6 +170,8 @@ class ReplicaStatusProperty:
first_ready_time: Optional[float] = None
# None means sky.down is not called yet.
sky_down_status: Optional[ProcessStatus] = None
# The replica's spot instance was preempted.
preempted: bool = False

def is_scale_down_succeeded(self, initial_delay_seconds: int,
auto_restart: bool) -> bool:
Expand All @@ -193,6 +196,8 @@ def is_scale_down_succeeded(self, initial_delay_seconds: int,
return True
if self.user_app_failed:
return False
if self.preempted:
return True
if not self.service_ready_now:
return False
return self.first_ready_time is not None
Expand All @@ -204,13 +209,18 @@ def should_track_status(self) -> bool:
return False
if self.user_app_failed:
return False
if self.preempted:
return True
return True

def to_replica_status(self) -> serve_state.ReplicaStatus:
if self.sky_launch_status == ProcessStatus.RUNNING:
# Still launching
return serve_state.ReplicaStatus.PROVISIONING
if self.sky_down_status is not None:
if self.preempted:
# Replica (spot) is preempted
return serve_state.ReplicaStatus.PREEMPTED
if self.sky_down_status == ProcessStatus.RUNNING:
# sky.down is running
return serve_state.ReplicaStatus.SHUTTING_DOWN
Expand Down Expand Up @@ -472,6 +482,8 @@ def _sync_down_logs():
info = serve_state.get_replica_info_from_id(self.service_name,
replica_id)
assert info is not None
logger.info(f'preempted: {info.status_property.preempted}, '
f'replica_id: {replica_id}')
log_file_name = serve_utils.generate_replica_down_log_file_name(
self.service_name, replica_id)
p = multiprocessing.Process(
Expand All @@ -490,6 +502,16 @@ def scale_down(self, replica_ids: List[int]) -> None:
for replica_id in replica_ids:
self._terminate_replica(replica_id, sync_down_logs=False)

def _recover_from_preemption(self, replica_id: int) -> None:
    """Flag a preempted replica and tear it down so it can be relaunched.

    Marks the replica's status property as preempted in the serve state
    (so downstream bookkeeping treats the termination as recoverable
    rather than a user failure), then terminates the underlying cluster
    without syncing down its logs.
    """
    logger.info(f'Beginning recovery for preempted replica {replica_id}.')
    # TODO(MaoZiming): Support spot recovery policies
    replica_info = serve_state.get_replica_info_from_id(
        self.service_name, replica_id)
    assert replica_info is not None
    replica_info.status_property.preempted = True
    serve_state.add_or_update_replica(self.service_name, replica_id,
                                      replica_info)
    self._terminate_replica(replica_id, sync_down_logs=False)

#################################
# ReplicaManager Daemon Threads #
#################################
Expand Down Expand Up @@ -557,11 +579,16 @@ def _refresh_process_pool(self) -> None:
if info.status_property.is_scale_down_succeeded(
self.initial_delay_seconds, self.auto_restart):
# This means the cluster is deleted due to
# a scale down. Delete the replica info
# a scale down or the cluster is recovering
# from preemption. Delete the replica info
# so it won't count as a replica.
logger.info(f'Replica {replica_id} removed from the '
'replica table normally.')
serve_state.remove_replica(self.service_name, replica_id)
if info.status_property.preempted:
removal_reason = 'for preemption recovery'
else:
removal_reason = 'normally'
logger.info(f'Replica {replica_id} removed from the '
f'replica table {removal_reason}.')
else:
logger.info(f'Termination of replica {replica_id} '
'finished. Replica info is kept since some '
Expand Down Expand Up @@ -673,6 +700,36 @@ def _probe_all_replicas(self) -> None:
if info.status_property.first_ready_time is None:
info.status_property.first_ready_time = probe_time
else:
handle = info.handle
if handle is None:
logger.error('Cannot find handle for '
f'replica {info.replica_id}.')
elif handle.launched_resources is None:
logger.error('Cannot find launched_resources in handle'
f' for replica {info.replica_id}.')
elif handle.launched_resources.use_spot:
# Pull the actual cluster status
# from the cloud provider to
# determine whether the cluster is preempted.
(cluster_status,
_) = backends.backend_utils.refresh_cluster_status_handle(
info.cluster_name,
force_refresh_statuses=set(status_lib.ClusterStatus))

if cluster_status != status_lib.ClusterStatus.UP:
# The cluster is (partially) preempted.
# It can be down, INIT or STOPPED, based on the
# interruption behavior of the cloud.
# Spot recovery is needed.
cluster_status_str = (
'' if cluster_status is None else
f' (status: {cluster_status.value})')
logger.info(f'Replica {info.replica_id} '
f'is preempted{cluster_status_str}.')
self._recover_from_preemption(info.replica_id)

continue

if info.first_not_ready_time is None:
info.first_not_ready_time = probe_time
if info.status_property.first_ready_time is not None:
Expand Down
4 changes: 4 additions & 0 deletions sky/serve/serve_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class ReplicaStatus(enum.Enum):
# we should guarantee no resource leakage like regular sky.
FAILED_CLEANUP = 'FAILED_CLEANUP'

# The replica is a spot VM and it is preempted by the cloud provider.
PREEMPTED = 'PREEMPTED'

# Unknown status. This should never happen.
UNKNOWN = 'UNKNOWN'

Expand All @@ -101,6 +104,7 @@ def colored_str(self) -> str:
ReplicaStatus.FAILED_CLEANUP: colorama.Fore.RED,
ReplicaStatus.SHUTTING_DOWN: colorama.Fore.MAGENTA,
ReplicaStatus.FAILED: colorama.Fore.RED,
ReplicaStatus.PREEMPTED: colorama.Fore.MAGENTA,
ReplicaStatus.UNKNOWN: colorama.Fore.RED,
}

Expand Down
8 changes: 4 additions & 4 deletions sky/utils/cli_utils/status_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,14 +491,14 @@ def _get_replica_resources(replica_record: _ReplicaRecord) -> str:
return '-'
assert isinstance(handle, backends.CloudVmRayResourceHandle)
cloud = handle.launched_resources.cloud
launched_resource_str = f'{cloud}'
if handle.launched_resources.accelerators is None:
vcpu, _ = cloud.get_vcpus_mem_from_instance_type(
handle.launched_resources.instance_type)
launched_resource_str += f'(vCPU={int(vcpu)})'
hardware = f'vCPU={int(vcpu)}'
else:
launched_resource_str += f'({handle.launched_resources.accelerators})'
resources_str = (f'{handle.launched_nodes}x {launched_resource_str}')
hardware = f'{handle.launched_resources.accelerators}'
spot = '[Spot]' if handle.launched_resources.use_spot else ''
resources_str = f'{handle.launched_nodes}x {cloud}({spot}{hardware})'
return resources_str


Expand Down
17 changes: 17 additions & 0 deletions tests/skyserve/spot/recovery.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Service spec for the SkyServe spot-recovery smoke test: one spot replica
# on GCP running the example HTTP server.
resources:
  cloud: gcp
  cpus: 2+
  zone: us-central1-a
  use_spot: true

workdir: examples/serve/http_server

# Serve on port 8080 to match service.port below.
# NOTE(review): the original comment mentioned a jupyter service being
# terminated — presumably copied from another test; confirm intent.
run: python3 server.py --port 8080

service:
  port: 8080
  readiness_probe:
    path: /health
    # Give the server time to start before the first readiness probe.
    initial_delay_seconds: 20
  replicas: 1
6 changes: 6 additions & 0 deletions tests/skyserve/spot/user_bug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Simulates a user application that crashes shortly after starting.

Used by the SkyServe smoke tests to verify that a replica failing from a
user bug is reported as FAILED instead of triggering spot recovery.
"""
import time

if __name__ == "__main__":
    # Stay alive briefly so the replica reaches a started state before
    # failing, mimicking an app that crashes shortly after launch.
    time.sleep(30)
    # Raise explicitly instead of `assert False`: assertions are stripped
    # when Python runs with -O, which would silently keep the app alive
    # and defeat the purpose of this fixture.
    raise RuntimeError('Simulated user application bug.')
16 changes: 16 additions & 0 deletions tests/skyserve/spot/user_bug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Service spec for the SkyServe spot user-bug smoke test: one spot replica
# whose application crashes, which must NOT trigger spot recovery.
resources:
  cloud: gcp
  cpus: 2+
  zone: us-central1-a
  # Lowercase boolean for consistency with recovery.yaml and YAML 1.2.
  use_spot: true

workdir: tests/skyserve/spot

# user_bug.py exits after ~30s to simulate a user application bug; no
# server ever listens on the service port, so the replica never turns READY.
run: python3 user_bug.py

service:
  port: 8080
  readiness_probe:
    path: /health
    initial_delay_seconds: 20
  replicas: 1
56 changes: 56 additions & 0 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,62 @@ def generate_llm_test_command(prompt: str, expected_output: str) -> str:
run_one_test(test)


@pytest.mark.gcp
@pytest.mark.sky_serve
def test_skyserve_spot_recovery():
    """Tests that a preempted spot replica recovers and serves again."""
    name = _get_service_name()
    zone = 'us-central1-a'

    # Reference: test_spot_recovery_gcp
    def terminate_replica(replica_id: int) -> str:
        # Build a shell command that finds the replica's GCP instance by its
        # ray-cluster-name label and deletes it to simulate a preemption.
        cluster_name = serve.generate_replica_cluster_name(name, replica_id)
        query_cmd = (f'gcloud compute instances list --filter='
                     f'"(labels.ray-cluster-name:{cluster_name})" '
                     f'--zones={zone} --format="value(name)"')
        return (f'gcloud compute instances delete --zone={zone}'
                f' --quiet $({query_cmd})')

    test = Test(
        # Plain string: the literal has no placeholders (ruff F541).
        'test-skyserve-spot-recovery-gcp',
        [
            f'sky serve up -n {name} -y tests/skyserve/spot/recovery.yaml',
            _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
            f'{_get_serve_endpoint(name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"',
            # Delete the replica's VM to trigger spot recovery.
            terminate_replica(1),
            _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
            f'{_get_serve_endpoint(name)}; curl -L http://$endpoint | grep "Hi, SkyPilot here"',
        ],
        f'sky serve down -y {name}',
        timeout=20 * 60,
    )
    run_one_test(test)


@pytest.mark.gcp
@pytest.mark.sky_serve
def test_skyserve_spot_user_bug():
    """Tests that spot recovery doesn't occur for non-preemption failures."""
    name = _get_service_name()
    test = Test(
        # Plain strings below where literals have no placeholders (ruff F541).
        'test-skyserve-spot-user-bug-gcp',
        [
            f'sky serve up -n {name} -y tests/skyserve/spot/user_bug.yaml',
            _SERVE_WAIT_UNTIL_READY.format(name=name, replica_num=1),
            # After failure due to user bug, the service should fail instead
            # of triggering spot recovery (which would show PROVISIONING).
            '(while true; do'
            f' output=$(sky serve status {name});'
            ' echo "$output" | grep -q "FAILED" && break;'
            ' echo "$output" | grep -q "PROVISIONING" && exit 1;'
            ' sleep 10;'
            'done)',
        ],
        f'sky serve down -y {name}',
        timeout=20 * 60,
    )
    run_one_test(test)


@pytest.mark.gcp
@pytest.mark.sky_serve
def test_skyserve_replica_failure():
Expand Down

0 comments on commit 2008590

Please sign in to comment.