diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 5ee0aba3d20..dec127c768f 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -869,7 +869,7 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]): def write_cluster_config( to_provision: 'resources.Resources', num_nodes: int, - ports: Optional[List[Union[int, str]]], + ports: Optional[List[str]], cluster_config_template: str, cluster_name: str, local_wheel_path: pathlib.Path, diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index a2bfa3c5a3f..4d30246c33d 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -625,7 +625,7 @@ def __init__( cluster_name: str, resources: resources_lib.Resources, num_nodes: int, - ports_to_open: Optional[List[Union[int, str]]] = None, + ports_to_open: Optional[List[str]] = None, prev_cluster_status: Optional[status_lib.ClusterStatus] = None, prev_handle: Optional['CloudVmRayResourceHandle'] = None, ) -> None: @@ -2875,7 +2875,7 @@ def _get_zone(runner): return handle def _open_inexistent_ports(self, handle: CloudVmRayResourceHandle, - ports_to_open: List[Union[int, str]]) -> None: + ports_to_open: List[str]) -> None: cloud = handle.launched_resources.cloud if not isinstance(cloud, (clouds.AWS, clouds.GCP, clouds.Azure)): logger.warning(f'Cannot open ports for {cloud} that not support ' @@ -2891,8 +2891,7 @@ def _update_after_cluster_provisioned( self, handle: CloudVmRayResourceHandle, task: task_lib.Task, prev_cluster_status: Optional[status_lib.ClusterStatus], ip_list: List[str], ssh_port_list: List[int], - ports_to_open: Optional[List[Union[int, - str]]], lock_path: str) -> None: + ports_to_open: Optional[List[str]], lock_path: str) -> None: usage_lib.messages.usage.update_cluster_resources( handle.launched_nodes, handle.launched_resources) usage_lib.messages.usage.update_final_cluster_status( diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index aa0bf7ef259..367de795724 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -80,7 +80,7 @@ def terminate_instances( def open_ports( provider_name: str, cluster_name_on_cloud: str, - ports: List[Union[int, str]], + ports: List[str], provider_config: Optional[Dict[str, Any]] = None, ) -> None: """Open ports for inbound traffic.""" diff --git a/sky/provision/aws/instance.py b/sky/provision/aws/instance.py index 77c7c93b6d5..33a7a9d0c12 100644 --- a/sky/provision/aws/instance.py +++ b/sky/provision/aws/instance.py @@ -1,7 +1,7 @@ """AWS instance provisioning.""" import re import time -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from botocore import config @@ -227,7 +227,7 @@ def _maybe_move_to_new_sg( def open_ports( cluster_name_on_cloud: str, - ports: List[Union[int, str]], + ports: List[str], provider_config: Optional[Dict[str, Any]] = None, ) -> None: """See sky/provision/__init__.py""" @@ -242,15 +242,13 @@ def open_ports( ip_permissions = [] for port in ports: - if isinstance(port, int): + if port.isdigit(): from_port = to_port = port else: - from_to_port = port.split('-') - from_port = int(from_to_port[0]) - to_port = int(from_to_port[1]) + from_port, to_port = port.split('-') ip_permissions.append({ - 'FromPort': from_port, - 'ToPort': to_port, + 'FromPort': int(from_port), + 'ToPort': int(to_port), 'IpProtocol': 'tcp', 'IpRanges': [{ 'CidrIp': '0.0.0.0/0' diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 32f8bde2260..21418821452 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -1,5 +1,5 @@ """Azure instance provisioning.""" -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional from sky import sky_logging from sky.adaptors import azure @@ -29,7 +29,7 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable: def open_ports( cluster_name_on_cloud: str, - ports: List[Union[int, str]], + ports: List[str], provider_config: Optional[Dict[str, Any]] = None, ) -> None: """See sky/provision/__init__.py""" @@ -39,7 +39,6 @@ def open_ports( network_client = azure.get_client('network', subscription_id) create_or_update = get_azure_sdk_function( client=network_client.security_rules, function_name='create_or_update') - ports = [str(port) for port in ports if port != 22] rule_name = f'user-ports-{"-".join(ports)}' def security_rule_parameters(priority: int) -> Dict[str, Any]: diff --git a/sky/provision/gcp/instance.py b/sky/provision/gcp/instance.py index d3f4d392e8f..da562130e68 100644 --- a/sky/provision/gcp/instance.py +++ b/sky/provision/gcp/instance.py @@ -2,7 +2,7 @@ import collections import re import time -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type from sky import sky_logging from sky.adaptors import gcp @@ -175,7 +175,7 @@ def terminate_instances( def open_ports( cluster_name_on_cloud: str, - ports: List[Union[int, str]], + ports: List[str], provider_config: Optional[Dict[str, Any]] = None, ) -> None: """See sky/provision/__init__.py""" diff --git a/sky/provision/gcp/instance_utils.py b/sky/provision/gcp/instance_utils.py index fbbae2ddb19..f0f0e1fd321 100644 --- a/sky/provision/gcp/instance_utils.py +++ b/sky/provision/gcp/instance_utils.py @@ -1,5 +1,5 @@ """Utilities for GCP instances.""" -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from sky import sky_logging from sky.adaptors import gcp @@ -76,7 +76,7 @@ def delete_firewall_rule( cls, project_id: str, cluster_name_on_cloud: str, - port: Union[int, str], + port: str, ) -> None: raise NotImplementedError @@ -94,7 +94,7 @@ def create_firewall_rule( cls, project_id: str, cluster_name_on_cloud: str, - port: Union[int, str], + port: str, vpc_name: str, ) -> dict: raise NotImplementedError @@ -102,7 +102,7 @@ def create_firewall_rule( def _get_firewall_rule_name( cluster_name_on_cloud: str, - port: Union[int, str], + port: str, ) -> str: return f'user-ports-{cluster_name_on_cloud}-{port}' @@ -230,7 +230,7 @@ def delete_firewall_rule( cls, project_id: str, cluster_name_on_cloud: str, - port: Union[int, str], + port: str, ) -> None: firewall_rule_name = _get_firewall_rule_name(cluster_name_on_cloud, port) @@ -269,7 +269,7 @@ def create_firewall_rule( cls, project_id: str, cluster_name_on_cloud: str, - port: Union[int, str], + port: str, vpc_name: str, ) -> dict: name = _get_firewall_rule_name(cluster_name_on_cloud, port) @@ -282,7 +282,7 @@ def create_firewall_rule( 'priority': 65534, 'allowed': [{ 'IPProtocol': 'tcp', - 'ports': [str(port)], + 'ports': [port], },], 'sourceRanges': ['0.0.0.0/0'], 'targetTags': [cluster_name_on_cloud], diff --git a/sky/resources.py b/sky/resources.py index 2ab21df0314..24a5c6a4f45 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -42,7 +42,7 @@ class Resources: """ # If any fields changed, increment the version. For backward compatibility, # modify the __setstate__ method to handle the old version. - _VERSION = 12 + _VERSION = 13 def __init__( self, @@ -59,7 +59,7 @@ def __init__( image_id: Union[Dict[str, str], str, None] = None, disk_size: Optional[int] = None, disk_tier: Optional[Literal['high', 'medium', 'low']] = None, - ports: Optional[List[Union[int, str]]] = None, + ports: Optional[List[str]] = None, # Internal use only. _docker_login_config: Optional[command_runner.DockerLoginConfig] = None, _is_image_managed: Optional[bool] = None, @@ -169,6 +169,8 @@ def __init__( self._is_image_managed = _is_image_managed self._disk_tier = disk_tier + if ports is not None: + ports = [str(port) for port in ports] self._ports = ports self._docker_login_config = _docker_login_config @@ -367,7 +369,7 @@ def disk_tier(self) -> str: return self._disk_tier @property - def ports(self) -> Optional[List[Union[int, str]]]: + def ports(self) -> Optional[List[str]]: return self._ports @property @@ -800,33 +802,35 @@ def _try_validate_ports(self) -> None: self.cloud.check_features_are_supported( {clouds.CloudImplementationFeatures.OPEN_PORTS}) for port in self.ports: - if isinstance(port, int): - if port < 1 or port > 65535: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Invalid port {port}. Please use a port number ' - 'between 1 and 65535.') - elif isinstance(port, str): - port_range = port.split('-') - if len(port_range) != 2: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Invalid port {port}. Please use a port range ' - 'such as 10022-10040.') - try: - from_port = int(port_range[0]) - to_port = int(port_range[1]) - except ValueError as e: - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Invalid port {port}. Please use a integer inside' - ' the range.') from e - if (from_port < 1 or from_port > 65535 or to_port < 1 or - to_port > 65535): - with ux_utils.print_exception_no_traceback(): - raise ValueError( - f'Invalid port {port}. Please use port ' - 'numbers between 1 and 65535.') + if isinstance(port, str): + if port.isdigit(): + int_port = int(port) + if int_port < 1 or int_port > 65535: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Invalid port {port}. Please use a port ' + 'number between 1 and 65535.') + else: + port_range = port.split('-') + if len(port_range) != 2: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Invalid port {port}. Please use a port ' + 'range such as 10022-10040.') + try: + from_port = int(port_range[0]) + to_port = int(port_range[1]) + except ValueError as e: + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Invalid port {port}. Please use a integer ' + 'inside the range.') from e + if (from_port < 1 or from_port > 65535 or to_port < 1 or + to_port > 65535): + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Invalid port {port}. Please use port ' + 'numbers between 1 and 65535.') else: with ux_utils.print_exception_no_traceback(): raise ValueError( @@ -1213,4 +1217,7 @@ def __setstate__(self, state): if version < 12: self._docker_login_config = None + if version < 13: + state['_ports'] = [str(port) for port in state['_ports']] + self.__dict__.update(state) diff --git a/sky/utils/resources_utils.py b/sky/utils/resources_utils.py index eaa0a2fba85..dd1bb43d5d4 100644 --- a/sky/utils/resources_utils.py +++ b/sky/utils/resources_utils.py @@ -1,32 +1,32 @@ """Utility functions for resources.""" import itertools -from typing import List, Set, Union +from typing import List, Set # TODO(tian): Maybe we need more intuitive names for these functions. -def parse_ports(ports: List[Union[int, str]]) -> Set[int]: +def parse_ports(ports: List[str]) -> Set[int]: """Parse a list of ports into a set that containing no duplicates. For example, ['1-3', '5-7'] will be parsed to {1, 2, 3, 5, 6, 7}. """ port_set = set() - for p in ports: - if isinstance(p, int): - port_set.add(p) + for port in ports: + if port.isdigit(): + port_set.add(int(port)) else: - from_port, to_port = p.split('-') + from_port, to_port = port.split('-') port_set.update(range(int(from_port), int(to_port) + 1)) return port_set -def parse_port_set(port_set: Set[int]) -> List[Union[int, str]]: +def parse_port_set(port_set: Set[int]) -> List[str]: """Parse a set of ports into the skypilot ports format. This function will group consecutive ports together into a range, and keep the rest as is. For example, {1, 2, 3, 5, 6, 7} will be parsed to ['1-3', '5-7']. """ - ports: List[Union[int, str]] = [] + ports: List[str] = [] # Group consecutive ports together. # This algorithm is based on one observation: consecutive numbers # in a sorted list will have the same difference with their indices. @@ -37,15 +37,15 @@ def parse_port_set(port_set: Set[int]) -> List[Union[int, str]]: lambda x: x[1] - x[0]): port = [g[1] for g in group] if len(port) == 1: - ports.append(port[0]) + ports.append(str(port[0])) else: ports.append(f'{port[0]}-{port[-1]}') return ports -def simplify_ports(ports: List[Union[int, str]]) -> List[Union[int, str]]: +def simplify_ports(ports: List[str]) -> List[str]: """Simplify a list of ports. - For example, [1, 2, 3, '5-6', 7] will be simplified to ['1-3', '5-7']. + For example, ['1-2', '3', '5-6', '7'] will be simplified to ['1-3', '5-7']. """ return parse_port_set(parse_ports(ports))