Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Relax cluster name restriction and process cloud cluster name #3130

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions sky/backends/cloud_vm_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,7 @@ def provision_with_retries(
try:
# Recheck cluster name as the 'except:' block below may
# change the cloud assignment.
to_provision.cloud.check_cluster_name_is_valid(cluster_name)
common_utils.check_cluster_name_is_valid(cluster_name)
if dryrun:
cloud_user = None
else:
Expand Down Expand Up @@ -4342,10 +4342,7 @@ def _check_existing_cluster(
usage_lib.messages.usage.set_new_cluster()
# Use the task_cloud, because the cloud in `to_provision` can be changed
# later during the retry.
for resources in task.resources:
task_cloud = (resources.cloud
if resources.cloud is not None else clouds.Cloud)
task_cloud.check_cluster_name_is_valid(cluster_name)
common_utils.check_cluster_name_is_valid(cluster_name)

if to_provision is None:
# The cluster is recently terminated either by autostop or manually
Expand Down
11 changes: 1 addition & 10 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4074,16 +4074,7 @@ def spot_launch(
if prompt is not None:
click.confirm(prompt, default=True, abort=True, show_default=True)

for task in dag.tasks:
# We try our best to validate the cluster name before we launch the
# task. If the cloud is not specified, this will only validate the
# cluster name against the regex, and the cloud-specific validation will
# be done by the spot controller when actually launching the spot
# cluster.
for resources in task.resources:
task_cloud = (resources.cloud
if resources.cloud is not None else clouds.Cloud)
task_cloud.check_cluster_name_is_valid(name)
common_utils.check_cluster_name_is_valid(name)

sky.spot_launch(dag,
name,
Expand Down
24 changes: 0 additions & 24 deletions sky/clouds/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
"""
import collections
import enum
import re
import typing
from typing import Dict, Iterator, List, Optional, Set, Tuple

from sky import exceptions
from sky import skypilot_config
from sky.clouds import service_catalog
from sky.skylet import constants
from sky.utils import log_utils
from sky.utils import resources_utils
from sky.utils import ux_utils
Expand Down Expand Up @@ -536,28 +534,6 @@ def _unsupported_features_for_resources(
del resources
raise NotImplementedError

@classmethod
def check_cluster_name_is_valid(cls, cluster_name: str) -> None:
"""Errors out on invalid cluster names not supported by cloud providers.
Bans (including but not limited to) names that:
- are digits-only
- contain underscore (_)
Raises:
exceptions.InvalidClusterNameError: If the cluster name is invalid.
"""
if cluster_name is None:
return
valid_regex = constants.CLUSTER_NAME_VALID_REGEX
if re.fullmatch(valid_regex, cluster_name) is None:
with ux_utils.print_exception_no_traceback():
raise exceptions.InvalidClusterNameError(
f'Cluster name "{cluster_name}" is invalid; '
'ensure it is fully matched by regex (e.g., '
'only contains lower letters, numbers and dash): '
f'{valid_regex}')

@classmethod
def check_disk_tier_enabled(cls, instance_type: Optional[str],
disk_tier: resources_utils.DiskTier) -> None:
Expand Down
2 changes: 1 addition & 1 deletion sky/skylet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@

# In most clouds, cluster names can only contain lowercase letters, numbers
# and hyphens. We use this regex to validate the cluster name.
CLUSTER_NAME_VALID_REGEX = '[a-z]([-a-z0-9]*[a-z0-9])?'
CLUSTER_NAME_VALID_REGEX = '[a-zA-Z]([-_.a-zA-Z0-9]*[a-zA-Z0-9])?'

# Used for translate local file mounts to cloud storage. Please refer to
# sky/execution.py::_maybe_translate_local_file_mounts_and_sync_up for
Expand Down
65 changes: 50 additions & 15 deletions sky/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jsonschema
import yaml

from sky import exceptions
from sky import sky_logging
from sky.skylet import constants
from sky.utils import ux_utils
Expand Down Expand Up @@ -117,44 +118,77 @@ def _base36_encode(num: int) -> str:
return _base36_encode(int_value)


def make_cluster_name_on_cloud(cluster_name: str,
def check_cluster_name_is_valid(cluster_name: Optional[str]) -> None:
"""Errors out on invalid cluster names.

Bans (including but not limited to) names that:
- are digits-only
- start with invalid character, like hyphen

Raises:
exceptions.InvalidClusterNameError: If the cluster name is invalid.
"""
if cluster_name is None:
return
valid_regex = constants.CLUSTER_NAME_VALID_REGEX
if re.fullmatch(valid_regex, cluster_name) is None:
with ux_utils.print_exception_no_traceback():
raise exceptions.InvalidClusterNameError(
f'Cluster name "{cluster_name}" is invalid; '
'ensure it is fully matched by regex (e.g., '
'only contains letters, numbers and dash): '
f'{valid_regex}')


def make_cluster_name_on_cloud(display_name: str,
max_length: Optional[int] = 15,
add_user_hash: bool = True) -> str:
"""Generate valid cluster name on cloud that is unique to the user.

This is to map the cluster name to a valid length for cloud providers, e.g.
GCP limits the length of the cluster name to 35 characters. If the cluster
name with user hash is longer than max_length:
This is to map the cluster name to a valid length and character set for
cloud providers,
- e.g. GCP limits the length of the cluster name to 35 characters. If the
cluster name with user hash is longer than max_length:
1. Truncate it to max_length - cluster_hash - user_hash_length.
2. Append the hash of the cluster name
dtran24 marked this conversation as resolved.
Show resolved Hide resolved
- e.g. some cloud providers don't allow for uppercase letters, periods,
or underscores, so we convert it to lower case and replace those
characters with hyphens

Args:
cluster_name: The cluster name to be truncated and hashed.
display_name: The cluster name to be truncated, hashed, and
transformed.
max_length: The maximum length of the cluster name. If None, no
truncation is performed.
add_user_hash: Whether to append user hash to the cluster name.
"""

cluster_name_on_cloud = re.sub(r'[._]', '-', display_name).lower()
if display_name != cluster_name_on_cloud:
logger.debug(
f'The user specified cluster name {display_name} might be invalid '
f'on the cloud, we convert it to {cluster_name_on_cloud}.')
user_hash = ''
if add_user_hash:
user_hash = get_user_hash()[:USER_HASH_LENGTH_IN_CLUSTER_NAME]
user_hash = f'-{user_hash}'
user_hash_length = len(user_hash)

if (max_length is None or
len(cluster_name) <= max_length - user_hash_length):
return f'{cluster_name}{user_hash}'
len(cluster_name_on_cloud) <= max_length - user_hash_length):
return f'{cluster_name_on_cloud}{user_hash}'
# -1 is for the dash between cluster name and cluster name hash.
truncate_cluster_name_length = (max_length - CLUSTER_NAME_HASH_LENGTH - 1 -
user_hash_length)
truncate_cluster_name = cluster_name[:truncate_cluster_name_length]
truncate_cluster_name = cluster_name_on_cloud[:truncate_cluster_name_length]
if truncate_cluster_name.endswith('-'):
truncate_cluster_name = truncate_cluster_name.rstrip('-')
assert truncate_cluster_name_length > 0, (cluster_name, max_length)
cluster_name_hash = hashlib.md5(cluster_name.encode()).hexdigest()
assert truncate_cluster_name_length > 0, (cluster_name_on_cloud, max_length)
display_name_hash = hashlib.md5(display_name.encode()).hexdigest()
# Use base36 to reduce the length of the hash.
cluster_name_hash = base36_encode(cluster_name_hash)
display_name_hash = base36_encode(display_name_hash)
return (f'{truncate_cluster_name}'
f'-{cluster_name_hash[:CLUSTER_NAME_HASH_LENGTH]}{user_hash}')
f'-{display_name_hash[:CLUSTER_NAME_HASH_LENGTH]}{user_hash}')


def cluster_name_in_hint(cluster_name: str, cluster_name_on_cloud: str) -> str:
Expand Down Expand Up @@ -551,8 +585,9 @@ def validate_schema(obj, schema, err_msg_prefix='', skip_none=True):


def get_cleaned_username(username: str = '') -> str:
"""Cleans the username as some cloud provider have limitation on
characters usage such as dot (.) is not allowed in GCP.
"""Cleans the username. Dots and underscores are allowed, as we will
handle it when mapping to the cluster_name_on_cloud in
common_utils.make_cluster_name_on_cloud.

Clean up includes:
1. Making all characters lowercase
Expand All @@ -567,7 +602,7 @@ def get_cleaned_username(username: str = '') -> str:
"""
username = username or getpass.getuser()
username = username.lower()
username = re.sub(r'[^a-z0-9-]', '', username)
username = re.sub(r'[^a-z0-9-._]', '', username)
username = re.sub(r'^[0-9-]+', '', username)
username = re.sub(r'-$', '', username)
return username
Expand Down
51 changes: 51 additions & 0 deletions tests/unit_tests/test_common_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from unittest.mock import patch

import pytest

from sky import exceptions
from sky.utils import common_utils

MOCKED_USER_HASH = 'ab12cd34'


class TestCheckClusterNameIsValid:

def test_check(self):
common_utils.check_cluster_name_is_valid("lora")

def test_check_with_hyphen(self):
common_utils.check_cluster_name_is_valid("seed-1")

def test_check_with_characters_to_transform(self):
common_utils.check_cluster_name_is_valid("Cuda_11.8")

def test_check_when_starts_with_number(self):
with pytest.raises(exceptions.InvalidClusterNameError):
common_utils.check_cluster_name_is_valid("11.8cuda")

def test_check_with_invalid_characters(self):
with pytest.raises(exceptions.InvalidClusterNameError):
common_utils.check_cluster_name_is_valid("lor@")

def test_check_when_none(self):
common_utils.check_cluster_name_is_valid(None)


class TestMakeClusterNameOnCloud:

@patch('sky.utils.common_utils.get_user_hash')
def test_make(self, mock_get_user_hash):
mock_get_user_hash.return_value = MOCKED_USER_HASH
assert "lora-ab12" == common_utils.make_cluster_name_on_cloud("lora")

@patch('sky.utils.common_utils.get_user_hash')
def test_make_with_hyphen(self, mock_get_user_hash):
mock_get_user_hash.return_value = MOCKED_USER_HASH
assert "seed-1-ab12" == common_utils.make_cluster_name_on_cloud(
"seed-1")

@patch('sky.utils.common_utils.get_user_hash')
def test_make_with_characters_to_transform(self, mock_get_user_hash):
mock_get_user_hash.return_value = MOCKED_USER_HASH
assert "cuda-11-8-ab12" == common_utils.make_cluster_name_on_cloud(
"Cuda_11.8")
Loading