[Core][Provisioner] Support open ports on RunPod #3748

Merged
merged 13 commits into from
Aug 9, 2024
2 changes: 1 addition & 1 deletion examples/serve/http_server/server.py
@@ -27,7 +27,7 @@ def do_GET(self):

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SkyServe HTTP Test Server')
- parser.add_argument('--port', type=int, required=False, default=8081)
+ parser.add_argument('--port', type=int, required=False, default=8080)
Collaborator

Why do we change this port from 8081 to 8080?

Collaborator Author

The default RunPod image has nginx listening on port 8081 at launch:

(base) root@2c7c82d210d5:~# lsof -i :8081
COMMAND PID USER   FD   TYPE     DEVICE SIZE/OFF NODE NAME
nginx    40 root    9u  IPv4 3927576185      0t0  TCP *:8081 (LISTEN)

Collaborator

Can we add a comment somewhere (probably in the RunPod provisioner code) mentioning that port 8081 is taken by the RunPod image by default?

Collaborator Author

Added. Thanks!
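
For readers who want to check for this kind of collision before choosing a default port, here is a minimal sketch. It is not part of this PR: `is_port_free` is a hypothetical helper name, and it assumes the check runs on the same machine (or container) as the service that will bind the port.

```python
import socket


def is_port_free(port: int, host: str = '127.0.0.1') -> bool:
    """Returns True if `port` can currently be bound on `host`.

    Hypothetical helper for illustration; not part of SkyPilot.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            # Typically EADDRINUSE: another process already holds the port.
            return False
    return True
```

On the RunPod base image described above, `is_port_free(8081)` would be expected to return `False` while nginx is running, which is the reason the example server moved to 8080.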

args = parser.parse_args()

Handler = MyHttpRequestHandler
2 changes: 1 addition & 1 deletion examples/serve/http_server/task.yaml
@@ -13,7 +13,7 @@ service:
replicas: 2

resources:
- ports: 8081
+ ports: 8080
cpus: 2+

workdir: examples/serve/http_server
1 change: 1 addition & 0 deletions sky/backends/backend_utils.py
@@ -2757,6 +2757,7 @@ def get_endpoints(cluster: str,
cluster_records = get_clusters(include_controller=True,
refresh=False,
cluster_names=[cluster])
+ assert len(cluster_records) == 1, cluster_records
cluster_record = cluster_records[0]
if (not skip_status_check and
cluster_record['status'] != status_lib.ClusterStatus.UP):
60 changes: 44 additions & 16 deletions sky/backends/cloud_vm_ray_backend.py
@@ -1469,6 +1469,16 @@ def _retry_zones(
assert to_provision.region == region.name, (to_provision,
region)
num_nodes = handle.launched_nodes
+ # Some clouds, like RunPod, only support exposing ports during
+ # launch. For those clouds, we pass the ports to open in the
+ # `bulk_provision` to expose the ports during provisioning.
+ # If the `bulk_provision` is to apply on an existing cluster,
+ # it should be ignored by the underlying provisioner impl
+ # as it will only apply to newly-created instances.
+ ports_to_open_on_launch = (
Collaborator

Suggested change
- ports_to_open_on_launch = (
+ # Some clouds, like RunPod, only support exposing ports during launch. For those
+ # clouds, we pass the ports to open in the `bulk_provision` to expose the ports
+ # during provisioning.
+ # If the `bulk_provision` is to apply on an existing cluster, it should be ignored by
+ # the underlying provisioner implementation.
+ ports_to_open_on_launch = (

Please double check the above comment is true. We do ignore the ports to open when it is on an existing cluster, right?

Collaborator Author

@cblmemo cblmemo Aug 7, 2024

Currently, on RunPod this raises an error if any new ports need to be opened on an existing cluster. For example, provisioning with 8000-8010 and then launching with 8002,8009 does not raise an error, but launching with 8005-8015 does. Do you think we should change it to a warning?

Collaborator Author

After an offline discussion, we've decided to raise the error here. Merging now 🫡

+ list(resources_utils.port_ranges_to_set(to_provision.ports))
+ if to_provision.cloud.OPEN_PORTS_VERSION <=
+ clouds.OpenPortsVersion.LAUNCH_ONLY else None)
try:
provision_record = provisioner.bulk_provision(
to_provision.cloud,
@@ -1479,7 +1489,8 @@ def _retry_zones(
num_nodes=num_nodes,
cluster_yaml=handle.cluster_yaml,
prev_cluster_ever_up=prev_cluster_ever_up,
- log_dir=self.log_dir)
+ log_dir=self.log_dir,
+ ports_to_open_on_launch=ports_to_open_on_launch)
# NOTE: We will handle the logic of '_ensure_cluster_ray_started' #pylint: disable=line-too-long
# in 'provision_utils.post_provision_runtime_setup()' in the
# caller.
@@ -1937,8 +1948,9 @@ def provision_with_retries(
cloud_user = to_provision.cloud.get_current_user_identity()

requested_features = self._requested_features.copy()
- # Skip stop feature for Kubernetes controllers.
- if (isinstance(to_provision.cloud, clouds.Kubernetes) and
+ # Skip stop feature for Kubernetes and RunPod controllers.
+ if (isinstance(to_provision.cloud,
+ (clouds.Kubernetes, clouds.RunPod)) and
controller_utils.Controllers.from_name(cluster_name)
is not None):
assert (clouds.CloudImplementationFeatures.STOP
@@ -2975,9 +2987,12 @@ def _update_after_cluster_provisioned(
resources_utils.port_ranges_to_set(current_ports) -
resources_utils.port_ranges_to_set(prev_ports))
if open_new_ports:
- with rich_utils.safe_status(
- '[bold cyan]Launching - Opening new ports'):
- self._open_ports(handle)
+ cloud = handle.launched_resources.cloud
+ if not (cloud.OPEN_PORTS_VERSION <=
+ clouds.OpenPortsVersion.LAUNCH_ONLY):
+ with rich_utils.safe_status(
+ '[bold cyan]Launching - Opening new ports'):
+ self._open_ports(handle)

with timeline.Event('backend.provision.post_process'):
global_user_state.add_or_update_cluster(
@@ -4083,15 +4098,16 @@ def set_autostop(self,
# The core.autostop() function should have already checked that the
# cloud and resources support requested autostop.
if idle_minutes_to_autostop is not None:
- # Skip auto-stop for Kubernetes clusters.
- if (isinstance(handle.launched_resources.cloud, clouds.Kubernetes)
- and not down and idle_minutes_to_autostop >= 0):
+ # Skip auto-stop for Kubernetes and RunPod clusters.
+ if (isinstance(handle.launched_resources.cloud,
+ (clouds.Kubernetes, clouds.RunPod)) and not down and
+ idle_minutes_to_autostop >= 0):
# We should hit this code path only for the controllers on
- # Kubernetes clusters.
+ # Kubernetes and RunPod clusters.
assert (controller_utils.Controllers.from_name(
handle.cluster_name) is not None), handle.cluster_name
logger.info('Auto-stop is not supported for Kubernetes '
- 'clusters. Skipping.')
+ 'and RunPod clusters. Skipping.')
return

# Check if we're stopping spot
@@ -4274,12 +4290,24 @@ def _check_existing_cluster(
# Assume resources share the same ports.
for resource in task.resources:
assert resource.ports == list(task.resources)[0].ports
- all_ports = resources_utils.port_set_to_ranges(
- resources_utils.port_ranges_to_set(
- handle.launched_resources.ports) |
- resources_utils.port_ranges_to_set(
- list(task.resources)[0].ports))
+ requested_ports_set = resources_utils.port_ranges_to_set(
+ list(task.resources)[0].ports)
+ current_ports_set = resources_utils.port_ranges_to_set(
+ handle.launched_resources.ports)
+ all_ports = resources_utils.port_set_to_ranges(current_ports_set |
+ requested_ports_set)
to_provision = handle.launched_resources
+ if (to_provision.cloud.OPEN_PORTS_VERSION <=
+ clouds.OpenPortsVersion.LAUNCH_ONLY):
+ if not requested_ports_set <= current_ports_set:
+ current_cloud = to_provision.cloud
+ with ux_utils.print_exception_no_traceback():
+ raise exceptions.NotSupportedError(
+ 'Failed to open new ports on an existing cluster '
+ f'with the current cloud {current_cloud} as it only'
+ ' supports opening ports on launch of the cluster. '
+ 'Please terminate the existing cluster and launch '
+ 'a new cluster with the desired ports open.')
if all_ports:
to_provision = to_provision.copy(ports=all_ports)
return RetryingVmProvisioner.ToProvisionConfig(
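
The subset check in this hunk can be illustrated with the port ranges from the review thread. The sketch below is a simplified stand-in, not SkyPilot's code: `port_ranges_to_set` here is a hypothetical re-implementation that only handles strings such as `'8002'` and `'8000-8010'` (the real `resources_utils.port_ranges_to_set` may behave differently), and `can_reuse_cluster` is an invented name.

```python
from typing import List, Optional, Set


def port_ranges_to_set(ports: Optional[List[str]]) -> Set[int]:
    """Expands entries such as '8002' or '8000-8010' into a set of ints.

    Simplified, hypothetical stand-in for the SkyPilot helper of the same name.
    """
    result: Set[int] = set()
    for entry in ports or []:
        if '-' in entry:
            start, end = entry.split('-')
            result.update(range(int(start), int(end) + 1))
        else:
            result.add(int(entry))
    return result


def can_reuse_cluster(current: List[str], requested: List[str]) -> bool:
    """True if every requested port was already opened at launch."""
    return port_ranges_to_set(requested) <= port_ranges_to_set(current)


# A cluster launched with 8000-8010 can serve a request for 8002 and 8009,
# but not one for 8005-8015, because 8011-8015 were never opened.
print(can_reuse_cluster(['8000-8010'], ['8002', '8009']))
print(can_reuse_cluster(['8000-8010'], ['8005-8015']))
```

On a launch-only cloud the second case is where the diff raises `NotSupportedError` instead of trying to open the extra ports.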
1 change: 1 addition & 0 deletions sky/clouds/__init__.py
@@ -3,6 +3,7 @@
from sky.clouds.cloud import Cloud
from sky.clouds.cloud import cloud_in_iterable
from sky.clouds.cloud import CloudImplementationFeatures
+ from sky.clouds.cloud import OpenPortsVersion
from sky.clouds.cloud import ProvisionerVersion
from sky.clouds.cloud import Region
from sky.clouds.cloud import StatusVersion
20 changes: 20 additions & 0 deletions sky/clouds/cloud.py
@@ -93,6 +93,25 @@ def __ge__(self, other):
return self.value >= other.value


+ class OpenPortsVersion(enum.Enum):
+ """The version of the open ports implementation.
+
+ 1: Open ports on launching of the cluster only, cannot be modified after
+ provisioning of the cluster. This is for clouds like RunPod which only
+ accept the port argument on the VM creation API, and require the Web GUI
+ and a VM restart to update ports. We currently do not support updating
+ ports this way.
+ 2: Open ports after provisioning of the cluster, updatable. This is for most
+ of the cloud providers which allow opening ports using a programmable API
+ and won't affect the running VMs.
+ """
+ LAUNCH_ONLY = 'LAUNCH ONLY'
+ UPDATABLE = 'UPDATABLE'
+
+ def __le__(self, other):
+ versions = list(OpenPortsVersion)
+ return versions.index(self) <= versions.index(other)
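
To make the ordering concrete, here is a small self-contained sketch of how the `<=` comparison behaves. The enum mirrors the one in this diff; `ports_for_launch` is a hypothetical helper added only to show how a caller might gate on the version, and it is not a function in SkyPilot.

```python
import enum
from typing import List, Optional, Set


class OpenPortsVersion(enum.Enum):
    LAUNCH_ONLY = 'LAUNCH ONLY'
    UPDATABLE = 'UPDATABLE'

    def __le__(self, other):
        # Members compare by definition order, not by their string values.
        versions = list(OpenPortsVersion)
        return versions.index(self) <= versions.index(other)


def ports_for_launch(version: OpenPortsVersion,
                     ports: Set[int]) -> Optional[List[int]]:
    """Hypothetical helper: pass ports at launch only for launch-only clouds."""
    if version <= OpenPortsVersion.LAUNCH_ONLY:
        return sorted(ports)
    return None


print(ports_for_launch(OpenPortsVersion.LAUNCH_ONLY, {8080, 22}))
print(ports_for_launch(OpenPortsVersion.UPDATABLE, {8080}))
```

Comparing by index rather than by value means a future, more restrictive version could be added at the top of the enum and existing `<= LAUNCH_ONLY` checks would still include it.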


class Cloud:
"""A cloud provider."""

@@ -107,6 +126,7 @@ class Cloud:
# NOTE: new clouds being added should use the latest version, i.e. SKYPILOT.
PROVISIONER_VERSION = ProvisionerVersion.RAY_AUTOSCALER
STATUS_VERSION = StatusVersion.CLOUD_CLI
+ OPEN_PORTS_VERSION = OpenPortsVersion.UPDATABLE

@classmethod
def max_cluster_name_length(cls) -> Optional[int]:
6 changes: 1 addition & 5 deletions sky/clouds/runpod.py
@@ -30,9 +30,6 @@ class RunPod(clouds.Cloud):
clouds.CloudImplementationFeatures.MULTI_NODE:
('Multi-node not supported yet, as the interconnection among nodes '
'are non-trivial on RunPod.'),
- clouds.CloudImplementationFeatures.OPEN_PORTS:
- ('Opening ports is not '
- 'supported yet on RunPod.'),
clouds.CloudImplementationFeatures.IMAGE_ID:
('Specifying image ID is not supported on RunPod.'),
clouds.CloudImplementationFeatures.DOCKER_IMAGE:
@@ -43,14 +40,13 @@ class RunPod(clouds.Cloud):
('Mounting object stores is not supported on RunPod. To read data '
'from object stores on RunPod, use `mode: COPY` to copy the data '
'to local disk.'),
- clouds.CloudImplementationFeatures.HOST_CONTROLLERS:
- ('Host controllers are not supported on RunPod.'),
}
_MAX_CLUSTER_NAME_LEN_LIMIT = 120
_regions: List[clouds.Region] = []

PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
STATUS_VERSION = clouds.StatusVersion.SKYPILOT
+ OPEN_PORTS_VERSION = clouds.OpenPortsVersion.LAUNCH_ONLY

@classmethod
def _unsupported_features_for_resources(
2 changes: 2 additions & 0 deletions sky/provision/common.py
@@ -52,6 +52,8 @@ class ProvisionConfig:
tags: Dict[str, str]
# Whether or not to resume stopped instances.
resume_stopped_nodes: bool
+ # Optional ports to open on launch of the cluster.
+ ports_to_open_on_launch: Optional[List[int]]


# -------------------- output data model -------------------- #
4 changes: 3 additions & 1 deletion sky/provision/provisioner.py
@@ -129,6 +129,7 @@ def bulk_provision(
cluster_yaml: str,
prev_cluster_ever_up: bool,
log_dir: str,
+ ports_to_open_on_launch: Optional[List[int]] = None,
) -> provision_common.ProvisionRecord:
"""Provisions a cluster and wait until fully provisioned.

@@ -150,7 +151,8 @@ def bulk_provision(
['node_config'],
count=num_nodes,
tags={},
- resume_stopped_nodes=True)
+ resume_stopped_nodes=True,
+ ports_to_open_on_launch=ports_to_open_on_launch)

with provision_logging.setup_provision_logging(log_dir):
try:
1 change: 1 addition & 0 deletions sky/provision/runpod/__init__.py
@@ -4,6 +4,7 @@
from sky.provision.runpod.instance import cleanup_ports
from sky.provision.runpod.instance import get_cluster_info
from sky.provision.runpod.instance import query_instances
+ from sky.provision.runpod.instance import query_ports
from sky.provision.runpod.instance import run_instances
from sky.provision.runpod.instance import stop_instances
from sky.provision.runpod.instance import terminate_instances
34 changes: 28 additions & 6 deletions sky/provision/runpod/instance.py
@@ -7,6 +7,7 @@
from sky.provision import common
from sky.provision.runpod import utils
from sky.utils import common_utils
+ from sky.utils import resources_utils
from sky.utils import ux_utils

POLL_INTERVAL = 5
@@ -15,12 +16,13 @@


def _filter_instances(cluster_name_on_cloud: str,
- status_filters: Optional[List[str]]) -> Dict[str, Any]:
+ status_filters: Optional[List[str]],
+ head_only: bool = False) -> Dict[str, Any]:

instances = utils.list_instances()
- possible_names = [
- f'{cluster_name_on_cloud}-head', f'{cluster_name_on_cloud}-worker'
- ]
+ possible_names = [f'{cluster_name_on_cloud}-head']
+ if not head_only:
+ possible_names.append(f'{cluster_name_on_cloud}-worker')

filtered_instances = {}
for instance_id, instance in instances.items():
@@ -83,7 +85,8 @@ def run_instances(region: str, cluster_name_on_cloud: str,
name=f'{cluster_name_on_cloud}-{node_type}',
instance_type=config.node_config['InstanceType'],
region=region,
- disk_size=config.node_config['DiskSize'])
+ disk_size=config.node_config['DiskSize'],
+ ports=config.ports_to_open_on_launch)
except Exception as e: # pylint: disable=broad-except
logger.warning(f'run_instances error: {e}')
raise
@@ -205,6 +208,25 @@ def query_instances(

def cleanup_ports(
cluster_name_on_cloud: str,
+ ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
- del cluster_name_on_cloud, provider_config
+ del cluster_name_on_cloud, ports, provider_config # Unused.
+
+
+ def query_ports(
+ cluster_name_on_cloud: str,
+ ports: List[str],
+ head_ip: Optional[str] = None,
+ provider_config: Optional[Dict[str, Any]] = None,
+ ) -> Dict[int, List[common.Endpoint]]:
+ """See sky/provision/__init__.py"""
+ del head_ip, provider_config # Unused.
+ instances = _filter_instances(cluster_name_on_cloud, None, head_only=True)
+ assert len(instances) == 1
+ head_inst = list(instances.values())[0]
+ return {
+ port: [common.SocketEndpoint(**endpoint)]
+ for port, endpoint in head_inst['port2endpoint'].items()
+ if port in resources_utils.port_ranges_to_set(ports)
+ }
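
As a rough illustration of what `query_ports` returns, the sketch below applies the same dict comprehension to sample data. Everything here is an assumption for illustration: `SocketEndpoint` is a minimal stand-in for `common.SocketEndpoint`, `filter_endpoints` is an invented name, and the IP address and port numbers are made up.

```python
import dataclasses
from typing import Any, Dict, List, Set


@dataclasses.dataclass
class SocketEndpoint:
    """Hypothetical stand-in for `common.SocketEndpoint`."""
    host: str
    port: int


def filter_endpoints(port2endpoint: Dict[int, Dict[str, Any]],
                     requested: Set[int]) -> Dict[int, List[SocketEndpoint]]:
    """Keeps only the requested private ports, one endpoint per port."""
    return {
        port: [SocketEndpoint(**endpoint)]
        for port, endpoint in port2endpoint.items()
        if port in requested
    }


# Sample data: private port -> public host/port mapping on the head node.
head_port2endpoint = {
    22: {'host': '203.0.113.5', 'port': 40022},
    8080: {'host': '203.0.113.5', 'port': 41234},
}
endpoints = filter_endpoints(head_port2endpoint, {8080})
print(endpoints)
```

The keys are the ports the user asked for, while each endpoint carries the public port assigned by the provider, which generally differs from the private one.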
25 changes: 19 additions & 6 deletions sky/provision/runpod/utils.py
@@ -1,7 +1,7 @@
"""RunPod library wrapper for SkyPilot."""

import time
- from typing import Any, Dict, List
+ from typing import Any, Dict, List, Optional

from sky import sky_logging
from sky.adaptors import runpod
@@ -74,21 +74,28 @@ def list_instances() -> Dict[str, Dict[str, Any]]:

info['status'] = instance['desiredStatus']
info['name'] = instance['name']
+ info['port2endpoint'] = {}

if instance['desiredStatus'] == 'RUNNING' and instance.get('runtime'):
for port in instance['runtime']['ports']:
- if port['privatePort'] == 22 and port['isIpPublic']:
- info['external_ip'] = port['ip']
- info['ssh_port'] = port['publicPort']
- elif not port['isIpPublic']:
+ if port['isIpPublic']:
+ if port['privatePort'] == 22:
+ info['external_ip'] = port['ip']
+ info['ssh_port'] = port['publicPort']
+ info['port2endpoint'][port['privatePort']] = {
+ 'host': port['ip'],
+ 'port': port['publicPort']
+ }
+ else:
info['internal_ip'] = port['ip']

instance_dict[instance['id']] = info

return instance_dict
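
The new port-handling loop can be exercised in isolation. This sketch reflects one reading of the flattened diff above (indentation was lost in the page capture, so the nesting is an assumption): every public port is recorded in `port2endpoint`, and port 22 additionally sets the SSH fields. `summarize_ports` is a hypothetical name and the sample entries are invented.

```python
from typing import Any, Dict, List


def summarize_ports(runtime_ports: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Builds the per-instance port info from RunPod-style port entries."""
    info: Dict[str, Any] = {'port2endpoint': {}}
    for port in runtime_ports:
        if port['isIpPublic']:
            if port['privatePort'] == 22:
                info['external_ip'] = port['ip']
                info['ssh_port'] = port['publicPort']
            # Every publicly mapped port gets an endpoint entry, SSH included.
            info['port2endpoint'][port['privatePort']] = {
                'host': port['ip'],
                'port': port['publicPort'],
            }
        else:
            info['internal_ip'] = port['ip']
    return info


# Invented sample entries, shaped like the fields the diff reads.
sample_ports = [
    {'ip': '203.0.113.5', 'isIpPublic': True,
     'privatePort': 22, 'publicPort': 40022},
    {'ip': '203.0.113.5', 'isIpPublic': True,
     'privatePort': 8080, 'publicPort': 41234},
    {'ip': '10.0.0.7', 'isIpPublic': False,
     'privatePort': 8265, 'publicPort': 60001},
]
print(summarize_ports(sample_ports))
```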


- def launch(name: str, instance_type: str, region: str, disk_size: int) -> str:
+ def launch(name: str, instance_type: str, region: str, disk_size: int,
+ ports: Optional[List[int]]) -> str:
"""Launches an instance with the given parameters.

Converts the instance_type to the RunPod GPU name, finds the specs for the
@@ -100,6 +107,11 @@ def launch(name: str, instance_type: str, region: str, disk_size: int) -> str:

gpu_specs = runpod.runpod.get_gpu(gpu_type)

+ # Port 8081 is occupied by nginx in the base image.
+ custom_ports_str = ''
+ if ports is not None:
+ custom_ports_str = ''.join([f'{p}/tcp,' for p in ports])

new_instance = runpod.runpod.create_pod(
name=name,
image_name='runpod/base:0.0.2',
@@ -111,6 +123,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int) -> str:
gpu_count=gpu_quantity,
country_code=region,
ports=(f'22/tcp,'
+ f'{custom_ports_str}'
f'{constants.SKY_REMOTE_RAY_DASHBOARD_PORT}/http,'
f'{constants.SKY_REMOTE_RAY_PORT}/http'),
support_public_ip=True,
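
For reference, the string assembled for the `ports` argument looks like the output of the sketch below. `build_ports_arg` is a hypothetical helper, and the dashboard and Ray port numbers passed in are placeholders rather than the actual values of `constants.SKY_REMOTE_RAY_DASHBOARD_PORT` and `constants.SKY_REMOTE_RAY_PORT`.

```python
from typing import List, Optional


def build_ports_arg(ports: Optional[List[int]], dashboard_port: int,
                    ray_port: int) -> str:
    """Joins SSH, user-requested, and internal ports into one spec string."""
    custom_ports_str = ''
    if ports is not None:
        # Each user port is exposed as TCP; the trailing comma separates it
        # from the next entry.
        custom_ports_str = ''.join([f'{p}/tcp,' for p in ports])
    return (f'22/tcp,'
            f'{custom_ports_str}'
            f'{dashboard_port}/http,'
            f'{ray_port}/http')


# Placeholder internal ports 9001 and 9002, for illustration only.
print(build_ports_arg([8080, 8000], 9001, 9002))
print(build_ports_arg(None, 9001, 9002))
```

When no user ports are requested the custom segment is empty, so the string reduces to the SSH entry plus the two internal ports.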
17 changes: 3 additions & 14 deletions sky/serve/controller.py
@@ -3,7 +3,6 @@
Responsible for autoscaling and replica management.
"""
import logging
- import os
import threading
import time
import traceback
@@ -160,17 +159,7 @@ def configure_logger():
# TODO(tian): Probably we should support service that will stop the VM in
# specific time period.
def run_controller(service_name: str, service_spec: serve.SkyServiceSpec,
- task_yaml: str, controller_port: int):
- # We expose the controller to the public network when running inside a
- # kubernetes cluster to allow external load balancers (example, for
- # high availability load balancers) to communicate with the controller.
- def _get_host():
- if 'KUBERNETES_SERVICE_HOST' in os.environ:
- return '0.0.0.0'
- else:
- return 'localhost'
-
- host = _get_host()
- controller = SkyServeController(service_name, service_spec, task_yaml, host,
- controller_port)
+ task_yaml: str, controller_host: str, controller_port: int):
+ controller = SkyServeController(service_name, service_spec, task_yaml,
+ controller_host, controller_port)
controller.run()