Skip to content

Commit

Permalink
[Core][Provisioner] Support open ports on RunPod (#3748)
Browse files Browse the repository at this point in the history
* init

* nits

* upd controller dependencies

* support runpod as controller

* add smoke test

* Apply suggestions from code review

Co-authored-by: Zhanghao Wu <[email protected]>

* upd enum

* comments for 8081

* comments

* comments

---------

Co-authored-by: Zhanghao Wu <[email protected]>
  • Loading branch information
cblmemo and Michaelvll authored Aug 9, 2024
1 parent 7f64d60 commit 1e6db6e
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 57 deletions.
2 changes: 1 addition & 1 deletion examples/serve/http_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def do_GET(self):

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SkyServe HTTP Test Server')
parser.add_argument('--port', type=int, required=False, default=8081)
parser.add_argument('--port', type=int, required=False, default=8080)
args = parser.parse_args()

Handler = MyHttpRequestHandler
Expand Down
2 changes: 1 addition & 1 deletion examples/serve/http_server/task.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ service:
replicas: 2

resources:
ports: 8081
ports: 8080
cpus: 2+

workdir: examples/serve/http_server
Expand Down
1 change: 1 addition & 0 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2757,6 +2757,7 @@ def get_endpoints(cluster: str,
cluster_records = get_clusters(include_controller=True,
refresh=False,
cluster_names=[cluster])
assert len(cluster_records) == 1, cluster_records
cluster_record = cluster_records[0]
if (not skip_status_check and
cluster_record['status'] != status_lib.ClusterStatus.UP):
Expand Down
60 changes: 44 additions & 16 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,16 @@ def _retry_zones(
assert to_provision.region == region.name, (to_provision,
region)
num_nodes = handle.launched_nodes
# Some clouds, like RunPod, only support exposing ports during
# launch. For those clouds, we pass the ports to open in the
# `bulk_provision` to expose the ports during provisioning.
# If the `bulk_provision` is to apply on an existing cluster,
# it should be ignored by the underlying provisioner impl
# as it will only apply to newly-created instances.
ports_to_open_on_launch = (
list(resources_utils.port_ranges_to_set(to_provision.ports))
if to_provision.cloud.OPEN_PORTS_VERSION <=
clouds.OpenPortsVersion.LAUNCH_ONLY else None)
try:
provision_record = provisioner.bulk_provision(
to_provision.cloud,
Expand All @@ -1479,7 +1489,8 @@ def _retry_zones(
num_nodes=num_nodes,
cluster_yaml=handle.cluster_yaml,
prev_cluster_ever_up=prev_cluster_ever_up,
log_dir=self.log_dir)
log_dir=self.log_dir,
ports_to_open_on_launch=ports_to_open_on_launch)
# NOTE: We will handle the logic of '_ensure_cluster_ray_started' #pylint: disable=line-too-long
# in 'provision_utils.post_provision_runtime_setup()' in the
# caller.
Expand Down Expand Up @@ -1937,8 +1948,9 @@ def provision_with_retries(
cloud_user = to_provision.cloud.get_current_user_identity()

requested_features = self._requested_features.copy()
# Skip stop feature for Kubernetes controllers.
if (isinstance(to_provision.cloud, clouds.Kubernetes) and
# Skip stop feature for Kubernetes and RunPod controllers.
if (isinstance(to_provision.cloud,
(clouds.Kubernetes, clouds.RunPod)) and
controller_utils.Controllers.from_name(cluster_name)
is not None):
assert (clouds.CloudImplementationFeatures.STOP
Expand Down Expand Up @@ -2975,9 +2987,12 @@ def _update_after_cluster_provisioned(
resources_utils.port_ranges_to_set(current_ports) -
resources_utils.port_ranges_to_set(prev_ports))
if open_new_ports:
with rich_utils.safe_status(
'[bold cyan]Launching - Opening new ports'):
self._open_ports(handle)
cloud = handle.launched_resources.cloud
if not (cloud.OPEN_PORTS_VERSION <=
clouds.OpenPortsVersion.LAUNCH_ONLY):
with rich_utils.safe_status(
'[bold cyan]Launching - Opening new ports'):
self._open_ports(handle)

with timeline.Event('backend.provision.post_process'):
global_user_state.add_or_update_cluster(
Expand Down Expand Up @@ -4083,15 +4098,16 @@ def set_autostop(self,
# The core.autostop() function should have already checked that the
# cloud and resources support requested autostop.
if idle_minutes_to_autostop is not None:
# Skip auto-stop for Kubernetes clusters.
if (isinstance(handle.launched_resources.cloud, clouds.Kubernetes)
and not down and idle_minutes_to_autostop >= 0):
# Skip auto-stop for Kubernetes and RunPod clusters.
if (isinstance(handle.launched_resources.cloud,
(clouds.Kubernetes, clouds.RunPod)) and not down and
idle_minutes_to_autostop >= 0):
# We should hit this code path only for the controllers on
# Kubernetes clusters.
# Kubernetes and RunPod clusters.
assert (controller_utils.Controllers.from_name(
handle.cluster_name) is not None), handle.cluster_name
logger.info('Auto-stop is not supported for Kubernetes '
'clusters. Skipping.')
'and RunPod clusters. Skipping.')
return

# Check if we're stopping spot
Expand Down Expand Up @@ -4274,12 +4290,24 @@ def _check_existing_cluster(
# Assume resources share the same ports.
for resource in task.resources:
assert resource.ports == list(task.resources)[0].ports
all_ports = resources_utils.port_set_to_ranges(
resources_utils.port_ranges_to_set(
handle.launched_resources.ports) |
resources_utils.port_ranges_to_set(
list(task.resources)[0].ports))
requested_ports_set = resources_utils.port_ranges_to_set(
list(task.resources)[0].ports)
current_ports_set = resources_utils.port_ranges_to_set(
handle.launched_resources.ports)
all_ports = resources_utils.port_set_to_ranges(current_ports_set |
requested_ports_set)
to_provision = handle.launched_resources
if (to_provision.cloud.OPEN_PORTS_VERSION <=
clouds.OpenPortsVersion.LAUNCH_ONLY):
if not requested_ports_set <= current_ports_set:
current_cloud = to_provision.cloud
with ux_utils.print_exception_no_traceback():
raise exceptions.NotSupportedError(
'Failed to open new ports on an existing cluster '
f'with the current cloud {current_cloud} as it only'
' supports opening ports on launch of the cluster. '
'Please terminate the existing cluster and launch '
'a new cluster with the desired ports open.')
if all_ports:
to_provision = to_provision.copy(ports=all_ports)
return RetryingVmProvisioner.ToProvisionConfig(
Expand Down
1 change: 1 addition & 0 deletions sky/clouds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sky.clouds.cloud import Cloud
from sky.clouds.cloud import cloud_in_iterable
from sky.clouds.cloud import CloudImplementationFeatures
from sky.clouds.cloud import OpenPortsVersion
from sky.clouds.cloud import ProvisionerVersion
from sky.clouds.cloud import Region
from sky.clouds.cloud import StatusVersion
Expand Down
20 changes: 20 additions & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,25 @@ def __ge__(self, other):
return self.value >= other.value


class OpenPortsVersion(enum.Enum):
"""The version of the open ports implementation.
1: Open ports on launching of the cluster only, cannot be modified after
provisioning of the cluster. This is for clouds like RunPod which only
accepts port argument on VM creation API, and requires Web GUI and an VM
restart to update ports. We currently do not support this.
2: Open ports after provisioning of the cluster, updatable. This is for most
of the cloud providers which allow opening ports using an programmable API
and won't affect the running VMs.
"""
LAUNCH_ONLY = 'LAUNCH ONLY'
UPDATABLE = 'UPDATABLE'

def __le__(self, other):
versions = list(OpenPortsVersion)
return versions.index(self) <= versions.index(other)


class Cloud:
"""A cloud provider."""

Expand All @@ -107,6 +126,7 @@ class Cloud:
# NOTE: new clouds being added should use the latest version, i.e. SKYPILOT.
PROVISIONER_VERSION = ProvisionerVersion.RAY_AUTOSCALER
STATUS_VERSION = StatusVersion.CLOUD_CLI
OPEN_PORTS_VERSION = OpenPortsVersion.UPDATABLE

@classmethod
def max_cluster_name_length(cls) -> Optional[int]:
Expand Down
6 changes: 1 addition & 5 deletions sky/clouds/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ class RunPod(clouds.Cloud):
clouds.CloudImplementationFeatures.MULTI_NODE:
('Multi-node not supported yet, as the interconnection among nodes '
'are non-trivial on RunPod.'),
clouds.CloudImplementationFeatures.OPEN_PORTS:
('Opening ports is not '
'supported yet on RunPod.'),
clouds.CloudImplementationFeatures.IMAGE_ID:
('Specifying image ID is not supported on RunPod.'),
clouds.CloudImplementationFeatures.DOCKER_IMAGE:
Expand All @@ -43,14 +40,13 @@ class RunPod(clouds.Cloud):
('Mounting object stores is not supported on RunPod. To read data '
'from object stores on RunPod, use `mode: COPY` to copy the data '
'to local disk.'),
clouds.CloudImplementationFeatures.HOST_CONTROLLERS:
('Host controllers are not supported on RunPod.'),
}
_MAX_CLUSTER_NAME_LEN_LIMIT = 120
_regions: List[clouds.Region] = []

PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
STATUS_VERSION = clouds.StatusVersion.SKYPILOT
OPEN_PORTS_VERSION = clouds.OpenPortsVersion.LAUNCH_ONLY

@classmethod
def _unsupported_features_for_resources(
Expand Down
2 changes: 2 additions & 0 deletions sky/provision/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class ProvisionConfig:
tags: Dict[str, str]
# Whether or not to resume stopped instances.
resume_stopped_nodes: bool
# Optional ports to open on launch of the cluster.
ports_to_open_on_launch: Optional[List[int]]


# -------------------- output data model -------------------- #
Expand Down
4 changes: 3 additions & 1 deletion sky/provision/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def bulk_provision(
cluster_yaml: str,
prev_cluster_ever_up: bool,
log_dir: str,
ports_to_open_on_launch: Optional[List[int]] = None,
) -> provision_common.ProvisionRecord:
"""Provisions a cluster and wait until fully provisioned.
Expand All @@ -150,7 +151,8 @@ def bulk_provision(
['node_config'],
count=num_nodes,
tags={},
resume_stopped_nodes=True)
resume_stopped_nodes=True,
ports_to_open_on_launch=ports_to_open_on_launch)

with provision_logging.setup_provision_logging(log_dir):
try:
Expand Down
1 change: 1 addition & 0 deletions sky/provision/runpod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sky.provision.runpod.instance import cleanup_ports
from sky.provision.runpod.instance import get_cluster_info
from sky.provision.runpod.instance import query_instances
from sky.provision.runpod.instance import query_ports
from sky.provision.runpod.instance import run_instances
from sky.provision.runpod.instance import stop_instances
from sky.provision.runpod.instance import terminate_instances
Expand Down
34 changes: 28 additions & 6 deletions sky/provision/runpod/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sky.provision import common
from sky.provision.runpod import utils
from sky.utils import common_utils
from sky.utils import resources_utils
from sky.utils import ux_utils

POLL_INTERVAL = 5
Expand All @@ -15,12 +16,13 @@


def _filter_instances(cluster_name_on_cloud: str,
status_filters: Optional[List[str]]) -> Dict[str, Any]:
status_filters: Optional[List[str]],
head_only: bool = False) -> Dict[str, Any]:

instances = utils.list_instances()
possible_names = [
f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker'
]
possible_names = [f'{cluster_name_on_cloud}-head']
if not head_only:
possible_names.append(f'{cluster_name_on_cloud}-worker')

filtered_instances = {}
for instance_id, instance in instances.items():
Expand Down Expand Up @@ -83,7 +85,8 @@ def run_instances(region: str, cluster_name_on_cloud: str,
name=f'{cluster_name_on_cloud}-{node_type}',
instance_type=config.node_config['InstanceType'],
region=region,
disk_size=config.node_config['DiskSize'])
disk_size=config.node_config['DiskSize'],
ports=config.ports_to_open_on_launch)
except Exception as e: # pylint: disable=broad-except
logger.warning(f'run_instances error: {e}')
raise
Expand Down Expand Up @@ -205,6 +208,25 @@ def query_instances(

def cleanup_ports(
cluster_name_on_cloud: str,
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
del cluster_name_on_cloud, provider_config
del cluster_name_on_cloud, ports, provider_config # Unused.


def query_ports(
cluster_name_on_cloud: str,
ports: List[str],
head_ip: Optional[str] = None,
provider_config: Optional[Dict[str, Any]] = None,
) -> Dict[int, List[common.Endpoint]]:
"""See sky/provision/__init__.py"""
del head_ip, provider_config # Unused.
instances = _filter_instances(cluster_name_on_cloud, None, head_only=True)
assert len(instances) == 1
head_inst = list(instances.values())[0]
return {
port: [common.SocketEndpoint(**endpoint)]
for port, endpoint in head_inst['port2endpoint'].items()
if port in resources_utils.port_ranges_to_set(ports)
}
25 changes: 19 additions & 6 deletions sky/provision/runpod/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""RunPod library wrapper for SkyPilot."""

import time
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from sky import sky_logging
from sky.adaptors import runpod
Expand Down Expand Up @@ -74,21 +74,28 @@ def list_instances() -> Dict[str, Dict[str, Any]]:

info['status'] = instance['desiredStatus']
info['name'] = instance['name']
info['port2endpoint'] = {}

if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'):
for port in instance['runtime']['ports']:
if port['privatePort'] == 22 and port['isIpPublic']:
info['external_ip'] = port['ip']
info['ssh_port'] = port['publicPort']
elif not port['isIpPublic']:
if port['isIpPublic']:
if port['privatePort'] == 22:
info['external_ip'] = port['ip']
info['ssh_port'] = port['publicPort']
info['port2endpoint'][port['privatePort']] = {
'host': port['ip'],
'port': port['publicPort']
}
else:
info['internal_ip'] = port['ip']

instance_dict[instance['id']] = info

return instance_dict


def launch(name: str, instance_type: str, region: str, disk_size: int) -> str:
def launch(name: str, instance_type: str, region: str, disk_size: int,
ports: Optional[List[int]]) -> str:
"""Launches an instance with the given parameters.
Converts the instance_type to the RunPod GPU name, finds the specs for the
Expand All @@ -100,6 +107,11 @@ def launch(name: str, instance_type: str, region: str, disk_size: int) -> str:

gpu_specs = runpod.runpod.get_gpu(gpu_type)

# Port 8081 is occupied for nginx in the base image.
custom_ports_str = ''
if ports is not None:
custom_ports_str = ''.join([f'{p}/tcp,' for p in ports])

new_instance = runpod.runpod.create_pod(
name=name,
image_name='runpod/base:0.0.2',
Expand All @@ -111,6 +123,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int) -> str:
gpu_count=gpu_quantity,
country_code=region,
ports=(f'22/tcp,'
f'{custom_ports_str}'
f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,'
f'{constants.SKY_REMOTE_RAY_PORT}/http'),
support_public_ip=True,
Expand Down
17 changes: 3 additions & 14 deletions sky/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Responsible for autoscaling and replica management.
"""
import logging
import os
import threading
import time
import traceback
Expand Down Expand Up @@ -160,17 +159,7 @@ def configure_logger():
# TODO(tian): Probably we should support service that will stop the VM in
# specific time period.
def run_controller(service_name: str, service_spec: serve.SkyServiceSpec,
task_yaml: str, controller_port: int):
# We expose the controller to the public network when running inside a
# kubernetes cluster to allow external load balancers (example, for
# high availability load balancers) to communicate with the controller.
def _get_host():
if 'KUBERNETES_SERVICE_HOST' in os.environ:
return '0.0.0.0'
else:
return 'localhost'

host = _get_host()
controller = SkyServeController(service_name, service_spec, task_yaml, host,
controller_port)
task_yaml: str, controller_host: str, controller_port: int):
controller = SkyServeController(service_name, service_spec, task_yaml,
controller_host, controller_port)
controller.run()
Loading

0 comments on commit 1e6db6e

Please sign in to comment.