Skip to content

Commit

Permalink
move sg / firewall name generation into make_deploy_variables
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Sep 17, 2023
1 parent 3b33aac commit d172b84
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 34 deletions.
18 changes: 6 additions & 12 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,10 @@ def write_cluster_config(
# is running a job with less resources than the cluster has.
cloud = to_provision.cloud
assert cloud is not None, to_provision

cluster_name_on_cloud = common_utils.make_cluster_name_on_cloud(
cluster_name, max_length=cloud.max_cluster_name_length())

# This can raise a ResourcesUnavailableError when:
# * The requested region/zones do not appear in the catalog. It can be
# triggered if the user changed the catalog file while there is a cluster
Expand All @@ -909,7 +913,8 @@ def write_cluster_config(
# move the check out of this function, i.e. the caller should be responsible
# for the validation.
# TODO(tian): Move more cloud agnostic vars to resources.py.
resources_vars = to_provision.make_deploy_variables(region, zones)
resources_vars = to_provision.make_deploy_variables(cluster_name_on_cloud,
region, zones)
config_dict = {}

azure_subscription_id = None
Expand Down Expand Up @@ -990,9 +995,6 @@ def write_cluster_config(
f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w"))\''
)

cluster_name_on_cloud = common_utils.make_cluster_name_on_cloud(
cluster_name, max_length=cloud.max_cluster_name_length())

# Use a tmp file path to avoid incomplete YAML file being re-used in the
# future.
tmp_yaml_path = yaml_path + '.tmp'
Expand All @@ -1011,14 +1013,6 @@ def write_cluster_config(
'SKYPILOT_USER', '')),

# AWS only:
# Temporary measure, as deleting per-cluster SGs is too slow.
# See https://github.com/skypilot-org/skypilot/pull/742.
# Generate the name of the security group we're looking for.
# (username, last 4 chars of hash of hostname): for uniquefying
# users on shared-account scenarios.
'security_group': skypilot_config.get_nested(
('aws', 'security_group_name'),
clouds.aws.DEFAULT_SECURITY_GROUP_NAME),
'vpc_name': skypilot_config.get_nested(('aws', 'vpc_name'),
None),
'use_internal_ips': skypilot_config.get_nested(
Expand Down
26 changes: 25 additions & 1 deletion sky/clouds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sky import exceptions
from sky import provision as provision_lib
from sky import sky_logging
from sky import skypilot_config
from sky.adaptors import aws
from sky.clouds import service_catalog
from sky.utils import common_utils
Expand Down Expand Up @@ -48,7 +49,15 @@
]

DEFAULT_AMI_GB = 45

# Temporary measure, as deleting per-cluster SGs is too slow.
# See https://github.com/skypilot-org/skypilot/pull/742.
# Generate the name of the security group we're looking for.
# (username, last 4 chars of hash of hostname): for uniquifying
# users in shared-account scenarios.
DEFAULT_SECURITY_GROUP_NAME = f'sky-sg-{common_utils.user_and_hostname_hash()}'
# Security group to use when user specified ports in their resources.
USER_PORTS_SECURITY_GROUP_NAME = 'sky-sg-{}'


class AWSIdentityType(enum.Enum):
Expand Down Expand Up @@ -336,7 +345,8 @@ def get_vcpus_mem_from_instance_type(
clouds='aws')

def make_deploy_resources_variables(
self, resources: 'resources_lib.Resources', region: 'clouds.Region',
self, resources: 'resources_lib.Resources',
cluster_name_on_cloud: str, region: 'clouds.Region',
zones: Optional[List['clouds.Zone']]) -> Dict[str, Any]:
assert zones is not None, (region, zones)

Expand All @@ -358,13 +368,27 @@ def make_deploy_resources_variables(
image_id = self._get_image_id(image_id_to_use, region_name,
r.instance_type)

user_security_group = skypilot_config.get_nested(
('aws', 'security_group_name'), None)
if resources.ports is not None:
# Already checked in Resources._try_validate_ports
assert user_security_group is None
security_group = USER_PORTS_SECURITY_GROUP_NAME.format(
cluster_name_on_cloud)
elif user_security_group is not None:
assert resources.ports is None
security_group = user_security_group
else:
security_group = DEFAULT_SECURITY_GROUP_NAME

return {
'instance_type': r.instance_type,
'custom_resources': custom_resources,
'use_spot': r.use_spot,
'region': region_name,
'zones': ','.join(zone_names),
'image_id': image_id,
'security_group': security_group,
**AWS._get_disk_specs(r.disk_tier)
}

Expand Down
4 changes: 3 additions & 1 deletion sky/clouds/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,10 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(
self, resources: 'resources.Resources', region: 'clouds.Region',
self, resources: 'resources.Resources', cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud # Unused.
assert zones is None, ('Azure does not support zones', zones)

region_name = region.name
Expand Down
1 change: 1 addition & 0 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def is_same_cloud(self, other):
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
region: 'Region',
zones: Optional[List['Zone']],
) -> Dict[str, Optional[str]]:
Expand Down
11 changes: 10 additions & 1 deletion sky/clouds/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@
# TODO(zhwu): Move the default AMI size to the catalog instead.
DEFAULT_GCP_IMAGE_GB = 50

USER_PORTS_FIREWALL_RULE_NAME = 'sky-ports-{}'


def _run_output(cmd):
proc = subprocess.run(cmd,
Expand Down Expand Up @@ -411,7 +413,8 @@ def get_default_instance_type(
clouds='gcp')

def make_deploy_resources_variables(
self, resources: 'resources.Resources', region: 'clouds.Region',
self, resources: 'resources.Resources', cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
assert zones is not None, (region, zones)

Expand Down Expand Up @@ -493,6 +496,12 @@ def make_deploy_resources_variables(

resources_vars['disk_tier'] = GCP._get_disk_type(r.disk_tier)

firewall_rule = None
if resources.ports is not None:
firewall_rule = (
USER_PORTS_FIREWALL_RULE_NAME.format(cluster_name_on_cloud))
resources_vars['firewall_rule'] = firewall_rule

return resources_vars

def _get_feasible_launchable_resources(
Expand Down
2 changes: 2 additions & 0 deletions sky/clouds/ibm.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def is_same_cloud(self, other):
def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name_on_cloud: str,
region: 'clouds.Region',
zones: Optional[List['clouds.Zone']],
) -> Dict[str, Optional[str]]:
Expand All @@ -177,6 +178,7 @@ def make_deploy_resources_variables(
Returns:
A dictionary of cloud-specific node type variables.
"""
del cluster_name_on_cloud # Unused.

def _get_profile_resources(instance_profile):
"""returns a dict representing the
Expand Down
4 changes: 2 additions & 2 deletions sky/clouds/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ def get_zone_shell_cmd(cls) -> Optional[str]:

def make_deploy_resources_variables(
self, resources: 'resources_lib.Resources',
region: Optional['clouds.Region'],
cluster_name_on_cloud: str, region: Optional['clouds.Region'],
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
del zones
del cluster_name_on_cloud, zones # Unused.
if region is None:
region = self._regions[0]

Expand Down
4 changes: 3 additions & 1 deletion sky/clouds/lambda_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,10 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(
self, resources: 'resources_lib.Resources', region: 'clouds.Region',
self, resources: 'resources_lib.Resources',
cluster_name_on_cloud: str, region: 'clouds.Region',
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud # Unused.
assert zones is None, 'Lambda does not support zones.'

r = resources
Expand Down
2 changes: 1 addition & 1 deletion sky/clouds/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def regions(cls) -> List[clouds.Region]:

def make_deploy_resources_variables(
self, resources: 'resources_lib.Resources',
region: Optional['clouds.Region'],
cluster_name_on_cloud: str, region: Optional['clouds.Region'],
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
return {}

Expand Down
3 changes: 2 additions & 1 deletion sky/clouds/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,9 @@ def get_zone_shell_cmd(cls) -> Optional[str]:

def make_deploy_resources_variables(
self, resources: 'resources_lib.Resources',
region: Optional['clouds.Region'],
cluster_name_on_cloud: str, region: Optional['clouds.Region'],
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud # Unused.
assert region is not None, resources

acc_dict = self.get_accelerators_from_instance_type(
Expand Down
4 changes: 3 additions & 1 deletion sky/clouds/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,10 @@ def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(
self, resources: 'resources_lib.Resources', region: 'clouds.Region',
self, resources: 'resources_lib.Resources',
cluster_name_on_cloud: str, region: 'clouds.Region',
zones: Optional[List['clouds.Zone']]) -> Dict[str, Optional[str]]:
del cluster_name_on_cloud # Unused.
assert zones is None, 'SCP does not support zones.'

r = resources
Expand Down
1 change: 1 addition & 0 deletions sky/provision/gcp/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def cleanup_ports(
project_id = provider_config['project_id']
if 'ports' in provider_config:
# Backward compatibility for old provider config.
# TODO(tian): remove this after 2 minor releases, 0.6.0.
for port in provider_config['ports']:
firewall_rule_name = f'user-ports-{cluster_name_on_cloud}-{port}'
instance_utils.GCPComputeInstance.delete_firewall_rule(
Expand Down
8 changes: 3 additions & 5 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def get_cost(self, seconds: float) -> float:
return hourly_cost * hours

def make_deploy_variables(
self, region: clouds.Region,
self, cluster_name_on_cloud: str, region: clouds.Region,
zones: Optional[List[clouds.Zone]]) -> Dict[str, Optional[str]]:
"""Converts planned sky.Resources to resource variables.
Expand All @@ -837,7 +837,7 @@ def make_deploy_variables(
variables are generated by this method.
"""
cloud_specific_variables = self.cloud.make_deploy_resources_variables(
self, region, zones)
self, cluster_name_on_cloud, region, zones)
docker_image = self.extract_docker_image()
return dict(
cloud_specific_variables,
Expand All @@ -852,9 +852,7 @@ def make_deploy_variables(
constants.DEFAULT_DOCKER_CONTAINER_NAME,
# Docker login config (if any). This helps pull the image from
# private registries.
'docker_login_config': self._docker_login_config,
# The cluster have ports requirement.
'have_ports': self.ports is not None,
'docker_login_config': self._docker_login_config
})

def get_reservations_available_resources(
Expand Down
7 changes: 1 addition & 6 deletions sky/templates/aws-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,8 @@ provider:
# teardown(terminate=True) will override this.
cache_stopped_nodes: True
security_group:
# AWS config file must include security group name, but we change
# to dedicted security group if we have ports requirement.
{% if have_ports %}
GroupName: sky-sg-{{cluster_name_on_cloud}}
{% else %}
# AWS config file must include security group name
GroupName: {{security_group}}
{% endif %}
{% if vpc_name is not none %}
# NOTE: This is a new field added by SkyPilot and parsed by our own
# AWSNodeProvider.
Expand Down
4 changes: 2 additions & 2 deletions sky/templates/gcp-ray.yml.j2
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ provider:
project_id: {{gcp_project_id}}
# The firewall rule name for customized firewall rules. Only set
# when the cluster has a ports requirement.
{% if have_ports %}
firewall_rule: sky-ports-{{cluster_name_on_cloud}}
{% if firewall_rule is not none %}
firewall_rule: {{firewall_rule}}
{% endif %}
{%- if docker_login_config is not none %}
# We put docker login config in provider section because ray's schema disabled
Expand Down

0 comments on commit d172b84

Please sign in to comment.