From 9a77adf8d75a00d6d2e82fecbf3bd4147b06eb3b Mon Sep 17 00:00:00 2001 From: David Tran Date: Mon, 12 Feb 2024 12:11:22 -0500 Subject: [PATCH] [Core] Relax cluster name restriction and process cloud cluster name (#3130) * [core] change regex to allow for uppercase, periods, and underscores for cluster name * [core] process cluster name when making it for the cloud * [core] add tests for make cluster name on cloud * [core] format * [core] add quotations to logged local cluster name * [core] move quotation marks inside parentheses * [core] update param name from local_cluster_name to display name for naming consistency * [core] update docstr * [core] simplify replace and lowercase * [core] use logger.debug * [core] remove unused exception and rename hash var * [core] move check cluster name method to common_utils * [core] update docstr for check cluster name method * [core] fix verb * [core] update docstr and debug for make cluster name on cloud method * [core] update docstr and include dot and underscore for regex --------- Co-authored-by: David Tran --- sky/backends/cloud_vm_ray_backend.py | 7 +-- sky/cli.py | 11 +---- sky/clouds/cloud.py | 24 ---------- sky/skylet/constants.py | 2 +- sky/utils/common_utils.py | 65 ++++++++++++++++++++------- tests/unit_tests/test_common_utils.py | 51 +++++++++++++++++++++ 6 files changed, 105 insertions(+), 55 deletions(-) create mode 100644 tests/unit_tests/test_common_utils.py diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 40435c5f22f..92fc0b135aa 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2043,7 +2043,7 @@ def provision_with_retries( try: # Recheck cluster name as the 'except:' block below may # change the cloud assignment. - to_provision.cloud.check_cluster_name_is_valid(cluster_name) + common_utils.check_cluster_name_is_valid(cluster_name) if dryrun: cloud_user = None else: @@ -4379,10 +4379,7 @@ def _check_existing_cluster( usage_lib.messages.usage.set_new_cluster() # Use the task_cloud, because the cloud in `to_provision` can be changed # later during the retry. - for resources in task.resources: - task_cloud = (resources.cloud - if resources.cloud is not None else clouds.Cloud) - task_cloud.check_cluster_name_is_valid(cluster_name) + common_utils.check_cluster_name_is_valid(cluster_name) if to_provision is None: # The cluster is recently terminated either by autostop or manually diff --git a/sky/cli.py b/sky/cli.py index 030c1473f45..a73edb5ce1d 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -4076,16 +4076,7 @@ def spot_launch( if prompt is not None: click.confirm(prompt, default=True, abort=True, show_default=True) - for task in dag.tasks: - # We try our best to validate the cluster name before we launch the - # task. If the cloud is not specified, this will only validate the - # cluster name against the regex, and the cloud-specific validation will - # be done by the spot controller when actually launching the spot - # cluster. - for resources in task.resources: - task_cloud = (resources.cloud - if resources.cloud is not None else clouds.Cloud) - task_cloud.check_cluster_name_is_valid(name) + common_utils.check_cluster_name_is_valid(name) sky.spot_launch(dag, name, diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 966a44d6c2c..331a7f9a728 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -9,14 +9,12 @@ """ import collections import enum -import re import typing from typing import Dict, Iterator, List, Optional, Set, Tuple from sky import exceptions from sky import skypilot_config from sky.clouds import service_catalog -from sky.skylet import constants from sky.utils import log_utils from sky.utils import resources_utils from sky.utils import ux_utils @@ -536,28 +534,6 @@ def _unsupported_features_for_resources( del resources raise NotImplementedError - @classmethod - def check_cluster_name_is_valid(cls, cluster_name: str) -> None: - """Errors out on invalid cluster names not supported by cloud providers. - - Bans (including but not limited to) names that: - - are digits-only - - contain underscore (_) - - Raises: - exceptions.InvalidClusterNameError: If the cluster name is invalid. - """ - if cluster_name is None: - return - valid_regex = constants.CLUSTER_NAME_VALID_REGEX - if re.fullmatch(valid_regex, cluster_name) is None: - with ux_utils.print_exception_no_traceback(): - raise exceptions.InvalidClusterNameError( - f'Cluster name "{cluster_name}" is invalid; ' - 'ensure it is fully matched by regex (e.g., ' - 'only contains lower letters, numbers and dash): ' - f'{valid_regex}') - @classmethod def check_disk_tier_enabled(cls, instance_type: Optional[str], disk_tier: resources_utils.DiskTier) -> None: diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 1ef51ca2c7b..412d8fc870d 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -93,7 +93,7 @@ # In most clouds, cluster names can only contain lowercase letters, numbers # and hyphens. We use this regex to validate the cluster name. -CLUSTER_NAME_VALID_REGEX = '[a-z]([-a-z0-9]*[a-z0-9])?' +CLUSTER_NAME_VALID_REGEX = '[a-zA-Z]([-_.a-zA-Z0-9]*[a-zA-Z0-9])?' # Used for translate local file mounts to cloud storage. Please refer to # sky/execution.py::_maybe_translate_local_file_mounts_and_sync_up for diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index dd3229308cd..87c7a970afa 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -21,6 +21,7 @@ import jsonschema import yaml +from sky import exceptions from sky import sky_logging from sky.skylet import constants from sky.utils import ux_utils @@ -117,23 +118,56 @@ def _base36_encode(num: int) -> str: return _base36_encode(int_value) -def make_cluster_name_on_cloud(cluster_name: str, +def check_cluster_name_is_valid(cluster_name: Optional[str]) -> None: + """Errors out on invalid cluster names. + + Bans (including but not limited to) names that: + - are digits-only + - start with invalid character, like hyphen + + Raises: + exceptions.InvalidClusterNameError: If the cluster name is invalid. + """ + if cluster_name is None: + return + valid_regex = constants.CLUSTER_NAME_VALID_REGEX + if re.fullmatch(valid_regex, cluster_name) is None: + with ux_utils.print_exception_no_traceback(): + raise exceptions.InvalidClusterNameError( + f'Cluster name "{cluster_name}" is invalid; ' + 'ensure it is fully matched by regex (e.g., ' + 'only contains letters, numbers and dash): ' + f'{valid_regex}') + + +def make_cluster_name_on_cloud(display_name: str, max_length: Optional[int] = 15, add_user_hash: bool = True) -> str: """Generate valid cluster name on cloud that is unique to the user. - This is to map the cluster name to a valid length for cloud providers, e.g. - GCP limits the length of the cluster name to 35 characters. If the cluster - name with user hash is longer than max_length: + This is to map the cluster name to a valid length and character set for + cloud providers, + - e.g. GCP limits the length of the cluster name to 35 characters. If the + cluster name with user hash is longer than max_length: 1. Truncate it to max_length - cluster_hash - user_hash_length. 2. Append the hash of the cluster name + - e.g. some cloud providers don't allow for uppercase letters, periods, + or underscores, so we convert it to lower case and replace those + characters with hyphens Args: - cluster_name: The cluster name to be truncated and hashed. + display_name: The cluster name to be truncated, hashed, and + transformed. max_length: The maximum length of the cluster name. If None, no truncation is performed. add_user_hash: Whether to append user hash to the cluster name. """ + + cluster_name_on_cloud = re.sub(r'[._]', '-', display_name).lower() + if display_name != cluster_name_on_cloud: + logger.debug( + f'The user specified cluster name {display_name} might be invalid ' + f'on the cloud, we convert it to {cluster_name_on_cloud}.') user_hash = '' if add_user_hash: user_hash = get_user_hash()[:USER_HASH_LENGTH_IN_CLUSTER_NAME] @@ -141,20 +175,20 @@ def make_cluster_name_on_cloud(cluster_name: str, user_hash_length = len(user_hash) if (max_length is None or - len(cluster_name) <= max_length - user_hash_length): - return f'{cluster_name}{user_hash}' + len(cluster_name_on_cloud) <= max_length - user_hash_length): + return f'{cluster_name_on_cloud}{user_hash}' # -1 is for the dash between cluster name and cluster name hash. truncate_cluster_name_length = (max_length - CLUSTER_NAME_HASH_LENGTH - 1 - user_hash_length) - truncate_cluster_name = cluster_name[:truncate_cluster_name_length] + truncate_cluster_name = cluster_name_on_cloud[:truncate_cluster_name_length] if truncate_cluster_name.endswith('-'): truncate_cluster_name = truncate_cluster_name.rstrip('-') - assert truncate_cluster_name_length > 0, (cluster_name, max_length) - cluster_name_hash = hashlib.md5(cluster_name.encode()).hexdigest() + assert truncate_cluster_name_length > 0, (cluster_name_on_cloud, max_length) + display_name_hash = hashlib.md5(display_name.encode()).hexdigest() # Use base36 to reduce the length of the hash. - cluster_name_hash = base36_encode(cluster_name_hash) + display_name_hash = base36_encode(display_name_hash) return (f'{truncate_cluster_name}' - f'-{cluster_name_hash[:CLUSTER_NAME_HASH_LENGTH]}{user_hash}') + f'-{display_name_hash[:CLUSTER_NAME_HASH_LENGTH]}{user_hash}') def cluster_name_in_hint(cluster_name: str, cluster_name_on_cloud: str) -> str: @@ -551,8 +585,9 @@ def validate_schema(obj, schema, err_msg_prefix='', skip_none=True): def get_cleaned_username(username: str = '') -> str: - """Cleans the username as some cloud provider have limitation on - characters usage such as dot (.) is not allowed in GCP. + """Cleans the username. Dots and underscores are allowed, as we will + handle it when mapping to the cluster_name_on_cloud in + common_utils.make_cluster_name_on_cloud. Clean up includes: 1. Making all characters lowercase @@ -567,7 +602,7 @@ def get_cleaned_username(username: str = '') -> str: """ username = username or getpass.getuser() username = username.lower() - username = re.sub(r'[^a-z0-9-]', '', username) + username = re.sub(r'[^a-z0-9-._]', '', username) username = re.sub(r'^[0-9-]+', '', username) username = re.sub(r'-$', '', username) return username diff --git a/tests/unit_tests/test_common_utils.py b/tests/unit_tests/test_common_utils.py new file mode 100644 index 00000000000..f38e14069e5 --- /dev/null +++ b/tests/unit_tests/test_common_utils.py @@ -0,0 +1,51 @@ +from unittest.mock import patch + +import pytest + +from sky import exceptions +from sky.utils import common_utils + +MOCKED_USER_HASH = 'ab12cd34' + + +class TestCheckClusterNameIsValid: + + def test_check(self): + common_utils.check_cluster_name_is_valid("lora") + + def test_check_with_hyphen(self): + common_utils.check_cluster_name_is_valid("seed-1") + + def test_check_with_characters_to_transform(self): + common_utils.check_cluster_name_is_valid("Cuda_11.8") + + def test_check_when_starts_with_number(self): + with pytest.raises(exceptions.InvalidClusterNameError): + common_utils.check_cluster_name_is_valid("11.8cuda") + + def test_check_with_invalid_characters(self): + with pytest.raises(exceptions.InvalidClusterNameError): + common_utils.check_cluster_name_is_valid("lor@") + + def test_check_when_none(self): + common_utils.check_cluster_name_is_valid(None) + + +class TestMakeClusterNameOnCloud: + + @patch('sky.utils.common_utils.get_user_hash') + def test_make(self, mock_get_user_hash): + mock_get_user_hash.return_value = MOCKED_USER_HASH + assert "lora-ab12" == common_utils.make_cluster_name_on_cloud("lora") + + @patch('sky.utils.common_utils.get_user_hash') + def test_make_with_hyphen(self, mock_get_user_hash): + mock_get_user_hash.return_value = MOCKED_USER_HASH + assert "seed-1-ab12" == common_utils.make_cluster_name_on_cloud( + "seed-1") + + @patch('sky.utils.common_utils.get_user_hash') + def test_make_with_characters_to_transform(self, mock_get_user_hash): + mock_get_user_hash.return_value = MOCKED_USER_HASH + assert "cuda-11-8-ab12" == common_utils.make_cluster_name_on_cloud( + "Cuda_11.8")