Skip to content

Commit

Permalink
change port type to List[str]
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Sep 14, 2023
1 parent 985c784 commit 7890348
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 67 deletions.
2 changes: 1 addition & 1 deletion sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def _restore_block(new_block: Dict[str, Any], old_block: Dict[str, Any]):
def write_cluster_config(
to_provision: 'resources.Resources',
num_nodes: int,
ports: Optional[List[Union[int, str]]],
ports: Optional[List[str]],
cluster_config_template: str,
cluster_name: str,
local_wheel_path: pathlib.Path,
Expand Down
7 changes: 3 additions & 4 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def __init__(
cluster_name: str,
resources: resources_lib.Resources,
num_nodes: int,
ports_to_open: Optional[List[Union[int, str]]] = None,
ports_to_open: Optional[List[str]] = None,
prev_cluster_status: Optional[status_lib.ClusterStatus] = None,
prev_handle: Optional['CloudVmRayResourceHandle'] = None,
) -> None:
Expand Down Expand Up @@ -2875,7 +2875,7 @@ def _get_zone(runner):
return handle

def _open_inexistent_ports(self, handle: CloudVmRayResourceHandle,
ports_to_open: List[Union[int, str]]) -> None:
ports_to_open: List[str]) -> None:
cloud = handle.launched_resources.cloud
if not isinstance(cloud, (clouds.AWS, clouds.GCP, clouds.Azure)):
logger.warning(f'Cannot open ports for {cloud} that not support '
Expand All @@ -2891,8 +2891,7 @@ def _update_after_cluster_provisioned(
self, handle: CloudVmRayResourceHandle, task: task_lib.Task,
prev_cluster_status: Optional[status_lib.ClusterStatus],
ip_list: List[str], ssh_port_list: List[int],
ports_to_open: Optional[List[Union[int,
str]]], lock_path: str) -> None:
ports_to_open: Optional[List[str]], lock_path: str) -> None:
usage_lib.messages.usage.update_cluster_resources(
handle.launched_nodes, handle.launched_resources)
usage_lib.messages.usage.update_final_cluster_status(
Expand Down
2 changes: 1 addition & 1 deletion sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def terminate_instances(
def open_ports(
provider_name: str,
cluster_name_on_cloud: str,
ports: List[Union[int, str]],
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""Open ports for inbound traffic."""
Expand Down
14 changes: 6 additions & 8 deletions sky/provision/aws/instance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""AWS instance provisioning."""
import re
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from botocore import config

Expand Down Expand Up @@ -227,7 +227,7 @@ def _maybe_move_to_new_sg(

def open_ports(
cluster_name_on_cloud: str,
ports: List[Union[int, str]],
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
Expand All @@ -242,15 +242,13 @@ def open_ports(

ip_permissions = []
for port in ports:
if isinstance(port, int):
if port.isdigit():
from_port = to_port = port
else:
from_to_port = port.split('-')
from_port = int(from_to_port[0])
to_port = int(from_to_port[1])
from_port, to_port = port.split('-')
ip_permissions.append({
'FromPort': from_port,
'ToPort': to_port,
'FromPort': int(from_port),
'ToPort': int(to_port),
'IpProtocol': 'tcp',
'IpRanges': [{
'CidrIp': '0.0.0.0/0'
Expand Down
5 changes: 2 additions & 3 deletions sky/provision/azure/instance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Azure instance provisioning."""
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

from sky import sky_logging
from sky.adaptors import azure
Expand Down Expand Up @@ -29,7 +29,7 @@ def get_azure_sdk_function(client: Any, function_name: str) -> Callable:

def open_ports(
cluster_name_on_cloud: str,
ports: List[Union[int, str]],
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
Expand All @@ -39,7 +39,6 @@ def open_ports(
network_client = azure.get_client('network', subscription_id)
create_or_update = get_azure_sdk_function(
client=network_client.security_rules, function_name='create_or_update')
ports = [str(port) for port in ports if port != 22]
rule_name = f'user-ports-{"-".join(ports)}'

def security_rule_parameters(priority: int) -> Dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions sky/provision/gcp/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import collections
import re
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

from sky import sky_logging
from sky.adaptors import gcp
Expand Down Expand Up @@ -175,7 +175,7 @@ def terminate_instances(

def open_ports(
cluster_name_on_cloud: str,
ports: List[Union[int, str]],
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
Expand Down
14 changes: 7 additions & 7 deletions sky/provision/gcp/instance_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Utilities for GCP instances."""
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from sky import sky_logging
from sky.adaptors import gcp
Expand Down Expand Up @@ -76,7 +76,7 @@ def delete_firewall_rule(
cls,
project_id: str,
cluster_name_on_cloud: str,
port: Union[int, str],
port: str,
) -> None:
raise NotImplementedError

Expand All @@ -94,15 +94,15 @@ def create_firewall_rule(
cls,
project_id: str,
cluster_name_on_cloud: str,
port: Union[int, str],
port: str,
vpc_name: str,
) -> dict:
raise NotImplementedError


def _get_firewall_rule_name(
cluster_name_on_cloud: str,
port: Union[int, str],
port: str,
) -> str:
return f'user-ports-{cluster_name_on_cloud}-{port}'

Expand Down Expand Up @@ -230,7 +230,7 @@ def delete_firewall_rule(
cls,
project_id: str,
cluster_name_on_cloud: str,
port: Union[int, str],
port: str,
) -> None:
firewall_rule_name = _get_firewall_rule_name(cluster_name_on_cloud,
port)
Expand Down Expand Up @@ -269,7 +269,7 @@ def create_firewall_rule(
cls,
project_id: str,
cluster_name_on_cloud: str,
port: Union[int, str],
port: str,
vpc_name: str,
) -> dict:
name = _get_firewall_rule_name(cluster_name_on_cloud, port)
Expand All @@ -282,7 +282,7 @@ def create_firewall_rule(
'priority': 65534,
'allowed': [{
'IPProtocol': 'tcp',
'ports': [str(port)],
'ports': [port],
},],
'sourceRanges': ['0.0.0.0/0'],
'targetTags': [cluster_name_on_cloud],
Expand Down
67 changes: 37 additions & 30 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Resources:
"""
# If any fields changed, increment the version. For backward compatibility,
# modify the __setstate__ method to handle the old version.
_VERSION = 12
_VERSION = 13

def __init__(
self,
Expand All @@ -59,7 +59,7 @@ def __init__(
image_id: Union[Dict[str, str], str, None] = None,
disk_size: Optional[int] = None,
disk_tier: Optional[Literal['high', 'medium', 'low']] = None,
ports: Optional[List[Union[int, str]]] = None,
ports: Optional[List[str]] = None,
# Internal use only.
_docker_login_config: Optional[command_runner.DockerLoginConfig] = None,
_is_image_managed: Optional[bool] = None,
Expand Down Expand Up @@ -169,6 +169,8 @@ def __init__(
self._is_image_managed = _is_image_managed

self._disk_tier = disk_tier
if ports is not None:
ports = [str(port) for port in ports]
self._ports = ports
self._docker_login_config = _docker_login_config

Expand Down Expand Up @@ -367,7 +369,7 @@ def disk_tier(self) -> str:
return self._disk_tier

@property
def ports(self) -> Optional[List[Union[int, str]]]:
def ports(self) -> Optional[List[str]]:
return self._ports

@property
Expand Down Expand Up @@ -800,33 +802,35 @@ def _try_validate_ports(self) -> None:
self.cloud.check_features_are_supported(
{clouds.CloudImplementationFeatures.OPEN_PORTS})
for port in self.ports:
if isinstance(port, int):
if port < 1 or port > 65535:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use a port number '
'between 1 and 65535.')
elif isinstance(port, str):
port_range = port.split('-')
if len(port_range) != 2:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use a port range '
'such as 10022-10040.')
try:
from_port = int(port_range[0])
to_port = int(port_range[1])
except ValueError as e:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use a integer inside'
' the range.') from e
if (from_port < 1 or from_port > 65535 or to_port < 1 or
to_port > 65535):
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use port '
'numbers between 1 and 65535.')
if isinstance(port, str):
if port.isdigit():
int_port = int(port)
if int_port < 1 or int_port > 65535:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use a port '
'number between 1 and 65535.')
else:
port_range = port.split('-')
if len(port_range) != 2:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use a port '
'range such as 10022-10040.')
try:
from_port = int(port_range[0])
to_port = int(port_range[1])
except ValueError as e:
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use a integer '
'inside the range.') from e
if (from_port < 1 or from_port > 65535 or to_port < 1 or
to_port > 65535):
with ux_utils.print_exception_no_traceback():
raise ValueError(
f'Invalid port {port}. Please use port '
'numbers between 1 and 65535.')
else:
with ux_utils.print_exception_no_traceback():
raise ValueError(
Expand Down Expand Up @@ -1213,4 +1217,7 @@ def __setstate__(self, state):
if version < 12:
self._docker_login_config = None

if version < 13:
state['_ports'] = [str(port) for port in state['_ports']]

self.__dict__.update(state)
22 changes: 11 additions & 11 deletions sky/utils/resources_utils.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
"""Utility functions for resources."""
import itertools
from typing import List, Set, Union
from typing import List, Set


# TODO(tian): Maybe we need more intuitive names for these functions.
def parse_ports(ports: List[Union[int, str]]) -> Set[int]:
def parse_ports(ports: List[str]) -> Set[int]:
"""Parse a list of ports into a set that containing no duplicates.
For example, ['1-3', '5-7'] will be parsed to {1, 2, 3, 5, 6, 7}.
"""
port_set = set()
for p in ports:
if isinstance(p, int):
port_set.add(p)
for port in ports:
if port.isdigit():
port_set.add(int(port))
else:
from_port, to_port = p.split('-')
from_port, to_port = port.split('-')
port_set.update(range(int(from_port), int(to_port) + 1))
return port_set


def parse_port_set(port_set: Set[int]) -> List[Union[int, str]]:
def parse_port_set(port_set: Set[int]) -> List[str]:
"""Parse a set of ports into the skypilot ports format.
This function will group consecutive ports together into a range,
and keep the rest as is. For example, {1, 2, 3, 5, 6, 7} will be
parsed to ['1-3', '5-7'].
"""
ports: List[Union[int, str]] = []
ports: List[str] = []
# Group consecutive ports together.
# This algorithm is based on one observation: consecutive numbers
# in a sorted list will have the same difference with their indices.
Expand All @@ -37,15 +37,15 @@ def parse_port_set(port_set: Set[int]) -> List[Union[int, str]]:
lambda x: x[1] - x[0]):
port = [g[1] for g in group]
if len(port) == 1:
ports.append(port[0])
ports.append(str(port[0]))
else:
ports.append(f'{port[0]}-{port[-1]}')
return ports


def simplify_ports(ports: List[Union[int, str]]) -> List[Union[int, str]]:
def simplify_ports(ports: List[str]) -> List[str]:
"""Simplify a list of ports.
For example, [1, 2, 3, '5-6', 7] will be simplified to ['1-3', '5-7'].
For example, ['1-2', '3', '5-6', '7'] will be simplified to ['1-3', '5-7'].
"""
return parse_port_set(parse_ports(ports))

0 comments on commit 7890348

Please sign in to comment.