From 5e23f16858cd0446af9d8404dfa8eb79a463f52c Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Wed, 10 Jul 2024 11:21:15 -0700 Subject: [PATCH] [Core] Task level config (#3689) * Add docker run options * Add docs * Add warning for docker run options in kubernetes * Add experimental config * fix * rename vars * type * format * wip * rename and add tests * Fixes and add tests * format * Assert for override configs specification * format * Add comments * fix * fix assertions * fix assertions * Fix test * fix * remove unsupported keys * format --- sky/backends/backend_utils.py | 33 ++----- sky/check.py | 4 +- sky/clouds/gcp.py | 10 +- sky/clouds/kubernetes.py | 22 +++-- sky/provision/kubernetes/utils.py | 22 +++-- sky/resources.py | 53 ++++++++++- sky/skylet/constants.py | 12 +++ sky/skypilot_config.py | 104 +++++++++++++++------ sky/task.py | 19 +++- sky/utils/schemas.py | 78 +++++++++++++++- tests/conftest.py | 2 +- tests/test_config.py | 146 +++++++++++++++++++++++++++++- 12 files changed, 425 insertions(+), 80 deletions(-) diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index a1c86fdb624..b80cf667413 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -874,23 +874,8 @@ def write_cluster_config( f'open(os.path.expanduser("{constants.SKY_REMOTE_RAY_PORT_FILE}"), "w", encoding="utf-8"))\'' ) - # Docker run options - docker_run_options = skypilot_config.get_nested(('docker', 'run_options'), - []) - if isinstance(docker_run_options, str): - docker_run_options = [docker_run_options] - if docker_run_options and isinstance(to_provision.cloud, clouds.Kubernetes): - logger.warning(f'{colorama.Style.DIM}Docker run options are specified, ' - 'but ignored for Kubernetes: ' - f'{" ".join(docker_run_options)}' - f'{colorama.Style.RESET_ALL}') - # Use a tmp file path to avoid incomplete YAML file being re-used in the # future. - initial_setup_commands = [] - if (skypilot_config.get_nested(('nvidia_gpus', 'disable_ecc'), False) and - to_provision.accelerators is not None): - initial_setup_commands.append(constants.DISABLE_GPU_ECC_COMMAND) tmp_yaml_path = yaml_path + '.tmp' common_utils.fill_template( cluster_config_template, @@ -922,8 +907,6 @@ def write_cluster_config( # currently only used by GCP. 'specific_reservations': specific_reservations, - # Initial setup commands. - 'initial_setup_commands': initial_setup_commands, # Conda setup 'conda_installation_commands': constants.CONDA_INSTALLATION_COMMANDS, @@ -935,9 +918,6 @@ def write_cluster_config( wheel_hash).replace('{cloud}', str(cloud).lower())), - # Docker - 'docker_run_options': docker_run_options, - # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 'ray_port': constants.SKY_REMOTE_RAY_PORT, @@ -976,17 +956,20 @@ def write_cluster_config( output_path=tmp_yaml_path) config_dict['cluster_name'] = cluster_name config_dict['ray'] = yaml_path + + # Add kubernetes config fields from ~/.sky/config + if isinstance(cloud, clouds.Kubernetes): + kubernetes_utils.combine_pod_config_fields( + tmp_yaml_path, + cluster_config_overrides=to_provision.cluster_config_overrides) + kubernetes_utils.combine_metadata_fields(tmp_yaml_path) + if dryrun: # If dryrun, return the unfinished tmp yaml path. config_dict['ray'] = tmp_yaml_path return config_dict _add_auth_to_cluster_config(cloud, tmp_yaml_path) - # Add kubernetes config fields from ~/.sky/config - if isinstance(cloud, clouds.Kubernetes): - kubernetes_utils.combine_pod_config_fields(tmp_yaml_path) - kubernetes_utils.combine_metadata_fields(tmp_yaml_path) - # Restore the old yaml content for backward compatibility. if os.path.exists(yaml_path) and keep_launch_fields_in_existing_config: with open(yaml_path, 'r', encoding='utf-8') as f: diff --git a/sky/check.py b/sky/check.py index e8a61317d63..c361c962c94 100644 --- a/sky/check.py +++ b/sky/check.py @@ -77,8 +77,8 @@ def get_all_clouds(): # Use allowed_clouds from config if it exists, otherwise check all clouds. # Also validate names with get_cloud_tuple. config_allowed_cloud_names = [ - get_cloud_tuple(c)[0] for c in skypilot_config.get_nested( - ['allowed_clouds'], get_all_clouds()) + get_cloud_tuple(c)[0] for c in skypilot_config.get_nested(( + 'allowed_clouds',), get_all_clouds()) ] # Use disallowed_cloud_names for logging the clouds that will be disabled # because they are not included in allowed_clouds in config.yaml. diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index f95f6dddfb3..86e9a90faf4 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -197,8 +197,10 @@ def _unsupported_features_for_resources( # because `skypilot_config` may change for an existing cluster. # Clusters created with MIG (only GPU clusters) cannot be stopped. if (skypilot_config.get_nested( - ('gcp', 'managed_instance_group'), None) is not None and - resources.accelerators): + ('gcp', 'managed_instance_group'), + None, + override_configs=resources.cluster_config_overrides) is not None + and resources.accelerators): unsupported[clouds.CloudImplementationFeatures.STOP] = ( 'Managed Instance Group (MIG) does not support stopping yet.') unsupported[clouds.CloudImplementationFeatures.SPOT_INSTANCE] = ( @@ -506,7 +508,9 @@ def make_deploy_resources_variables( resources_vars['tpu_node_name'] = tpu_node_name managed_instance_group_config = skypilot_config.get_nested( - ('gcp', 'managed_instance_group'), None) + ('gcp', 'managed_instance_group'), + None, + override_configs=resources.cluster_config_overrides) use_mig = managed_instance_group_config is not None resources_vars['gcp_use_managed_instance_group'] = use_mig # Convert boolean to 0 or 1 in string, as GCP does not support boolean diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 1e307f475c8..78471e0de9f 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -38,15 +38,6 @@ class Kubernetes(clouds.Cloud): SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys' SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod' - # Timeout for resource provisioning. This timeout determines how long to - # wait for pod to be in pending status before giving up. - # Larger timeout may be required for autoscaling clusters, since autoscaler - # may take some time to provision new nodes. - # Note that this timeout includes time taken by the Kubernetes scheduler - # itself, which can be upto 2-3 seconds. - # For non-autoscaling clusters, we conservatively set this to 10s. - timeout = skypilot_config.get_nested(['kubernetes', 'provision_timeout'], - 10) # Limit the length of the cluster name to avoid exceeding the limit of 63 # characters for Kubernetes resources. We limit to 42 characters (63-21) to @@ -309,6 +300,17 @@ def make_deploy_resources_variables( if resources.use_spot: spot_label_key, spot_label_value = kubernetes_utils.get_spot_label() + # Timeout for resource provisioning. This timeout determines how long to + # wait for pod to be in pending status before giving up. + # Larger timeout may be required for autoscaling clusters, since + # autoscaler may take some time to provision new nodes. + # Note that this timeout includes time taken by the Kubernetes scheduler + # itself, which can be upto 2-3 seconds. + # For non-autoscaling clusters, we conservatively set this to 10s. + timeout = skypilot_config.get_nested( + ('kubernetes', 'provision_timeout'), + 10, + override_configs=resources.cluster_config_overrides) deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, @@ -316,7 +318,7 @@ def make_deploy_resources_variables( 'cpus': str(cpus), 'memory': str(mem), 'accelerator_count': str(acc_count), - 'timeout': str(self.timeout), + 'timeout': str(timeout), 'k8s_namespace': kubernetes_utils.get_current_kube_config_context_namespace(), 'k8s_port_mode': port_mode.value, diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 41b43b82c2c..80bc96ddb94 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -1367,9 +1367,10 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): elif isinstance(value, list) and key in destination: assert isinstance(destination[key], list), \ f'Expected {key} to be a list, found {destination[key]}' - if key == 'containers': - # If the key is 'containers', we take the first and only - # container in the list and merge it. + if key in ['containers', 'imagePullSecrets']: + # If the key is 'containers' or 'imagePullSecrets, we take the + # first and only container/secret in the list and merge it, as + # we only support one container per pod. assert len(value) == 1, \ f'Expected only one container, found {value}' merge_dicts(value[0], destination[key][0]) @@ -1392,7 +1393,10 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): destination[key] = value -def combine_pod_config_fields(cluster_yaml_path: str) -> None: +def combine_pod_config_fields( + cluster_yaml_path: str, + cluster_config_overrides: Dict[str, Any], +) -> None: """Adds or updates fields in the YAML with fields from the ~/.sky/config's kubernetes.pod_spec dict. This can be used to add fields to the YAML that are not supported by @@ -1434,8 +1438,14 @@ def combine_pod_config_fields(cluster_yaml_path: str) -> None: with open(cluster_yaml_path, 'r', encoding='utf-8') as f: yaml_content = f.read() yaml_obj = yaml.safe_load(yaml_content) + # We don't use override_configs in `skypilot_config.get_nested`, as merging + # the pod config requires special handling. kubernetes_config = skypilot_config.get_nested(('kubernetes', 'pod_config'), - {}) + default_value={}, + override_configs={}) + override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get( + 'pod_config', {})) + merge_dicts(override_pod_config, kubernetes_config) # Merge the kubernetes config into the YAML for both head and worker nodes. merge_dicts( @@ -1567,7 +1577,7 @@ def get_head_pod_name(cluster_name_on_cloud: str): def get_autoscaler_type( ) -> Optional[kubernetes_enums.KubernetesAutoscalerType]: """Returns the autoscaler type by reading from config""" - autoscaler_type = skypilot_config.get_nested(['kubernetes', 'autoscaler'], + autoscaler_type = skypilot_config.get_nested(('kubernetes', 'autoscaler'), None) if autoscaler_type is not None: autoscaler_type = kubernetes_enums.KubernetesAutoscalerType( diff --git a/sky/resources.py b/sky/resources.py index 252edff5da6..38f7a9784e6 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -44,7 +44,7 @@ class Resources: """ # If any fields changed, increment the version. For backward compatibility, # modify the __setstate__ method to handle the old version. - _VERSION = 18 + _VERSION = 19 def __init__( self, @@ -68,6 +68,7 @@ def __init__( _docker_login_config: Optional[docker_utils.DockerLoginConfig] = None, _is_image_managed: Optional[bool] = None, _requires_fuse: Optional[bool] = None, + _cluster_config_overrides: Optional[Dict[str, Any]] = None, ): """Initialize a Resources object. @@ -218,6 +219,8 @@ def __init__( self._requires_fuse = _requires_fuse + self._cluster_config_overrides = _cluster_config_overrides + self._set_cpus(cpus) self._set_memory(memory) self._set_accelerators(accelerators, accelerator_args) @@ -448,6 +451,12 @@ def requires_fuse(self) -> bool: return False return self._requires_fuse + @property + def cluster_config_overrides(self) -> Dict[str, Any]: + if self._cluster_config_overrides is None: + return {} + return self._cluster_config_overrides + @requires_fuse.setter def requires_fuse(self, value: Optional[bool]) -> None: self._requires_fuse = value @@ -1011,13 +1020,39 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, cloud.make_deploy_resources_variables() method, and the cloud-agnostic variables are generated by this method. """ + # Initial setup commands + initial_setup_commands = [] + if (skypilot_config.get_nested( + ('nvidia_gpus', 'disable_ecc'), + False, + override_configs=self.cluster_config_overrides) and + self.accelerators is not None): + initial_setup_commands = [constants.DISABLE_GPU_ECC_COMMAND] + + # Docker run options + docker_run_options = skypilot_config.get_nested( + ('docker', 'run_options'), + default_value=[], + override_configs=self.cluster_config_overrides) + if isinstance(docker_run_options, str): + docker_run_options = [docker_run_options] + if docker_run_options and isinstance(self.cloud, clouds.Kubernetes): + logger.warning( + f'{colorama.Style.DIM}Docker run options are specified, ' + 'but ignored for Kubernetes: ' + f'{" ".join(docker_run_options)}' + f'{colorama.Style.RESET_ALL}') + + docker_image = self.extract_docker_image() + + # Cloud specific variables cloud_specific_variables = self.cloud.make_deploy_resources_variables( self, cluster_name_on_cloud, region, zones, dryrun) - docker_image = self.extract_docker_image() return dict( cloud_specific_variables, **{ # Docker config + 'docker_run_options': docker_run_options, # Docker image. The image name used to pull the image, e.g. # ubuntu:latest. 'docker_image': docker_image, @@ -1027,7 +1062,9 @@ def make_deploy_variables(self, cluster_name_on_cloud: str, constants.DEFAULT_DOCKER_CONTAINER_NAME, # Docker login config (if any). This helps pull the image from # private registries. - 'docker_login_config': self._docker_login_config + 'docker_login_config': self._docker_login_config, + # Initial setup commands. + 'initial_setup_commands': initial_setup_commands, }) def get_reservations_available_resources(self) -> Dict[str, int]: @@ -1208,6 +1245,8 @@ def copy(self, **override) -> 'Resources': _is_image_managed=override.pop('_is_image_managed', self._is_image_managed), _requires_fuse=override.pop('_requires_fuse', self._requires_fuse), + _cluster_config_overrides=override.pop( + '_cluster_config_overrides', self._cluster_config_overrides), ) assert len(override) == 0 return resources @@ -1367,6 +1406,8 @@ def _from_yaml_config_single(cls, config: Dict[str, str]) -> 'Resources': resources_fields['_is_image_managed'] = config.pop( '_is_image_managed', None) resources_fields['_requires_fuse'] = config.pop('_requires_fuse', None) + resources_fields['_cluster_config_overrides'] = config.pop( + '_cluster_config_overrides', None) if resources_fields['cpus'] is not None: resources_fields['cpus'] = str(resources_fields['cpus']) @@ -1410,6 +1451,8 @@ def add_if_not_none(key, value): if self._docker_login_config is not None: config['_docker_login_config'] = dataclasses.asdict( self._docker_login_config) + add_if_not_none('_cluster_config_overrides', + self._cluster_config_overrides) if self._is_image_managed is not None: config['_is_image_managed'] = self._is_image_managed if self._requires_fuse is not None: @@ -1525,4 +1568,8 @@ def __setstate__(self, state): if version < 18: self._job_recovery = state.pop('_spot_recovery', None) + if version < 19: + self._cluster_config_overrides = state.pop( + '_cluster_config_overrides', None) + self.__dict__.update(state) diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index c456b48b306..359914b51f9 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -1,4 +1,6 @@ """Constants for SkyPilot.""" +from typing import List, Tuple + from packaging import version import sky @@ -261,3 +263,13 @@ # Placeholder for the SSH user in proxy command, replaced when the ssh_user is # known after provisioning. SKY_SSH_USER_PLACEHOLDER = 'skypilot:ssh_user' + +# The keys that can be overridden in the `~/.sky/config.yaml` file. The +# overrides are specified in task YAMLs. +OVERRIDEABLE_CONFIG_KEYS: List[Tuple[str, ...]] = [ + ('docker', 'run_options'), + ('nvidia_gpus', 'disable_ecc'), + ('kubernetes', 'pod_config'), + ('kubernetes', 'provision_timeout'), + ('gcp', 'managed_instance_group'), +] diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index 5b205e2692a..52e1d0ae3d9 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -1,7 +1,7 @@ """Immutable user configurations (EXPERIMENTAL). -On module import, we attempt to parse the config located at CONFIG_PATH. Caller -can then use +On module import, we attempt to parse the config located at CONFIG_PATH +(default: ~/.sky/config.yaml). Caller can then use >> skypilot_config.loaded() @@ -11,6 +11,13 @@ >> skypilot_config.get_nested(('auth', 'some_auth_config'), default_value) +The config can be overridden by the configs in task YAMLs. Callers are +responsible to provide the override_configs. If the nested key is part of +OVERRIDEABLE_CONFIG_KEYS, override_configs must be provided (can be empty): + + >> skypilot_config.get_nested(('docker', 'run_options'), default_value + override_configs={'docker': {'run_options': 'value'}}) + To set a value in the nested-key config: >> config_dict = skypilot_config.set_nested(('auth', 'some_key'), value) @@ -44,11 +51,12 @@ import copy import os import pprint -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Tuple import yaml from sky import sky_logging +from sky.skylet import constants from sky.utils import common_utils from sky.utils import schemas from sky.utils import ux_utils @@ -73,19 +81,15 @@ logger = sky_logging.init_logger(__name__) # The loaded config. -_dict = None +_dict: Optional[Dict[str, Any]] = None _loaded_config_path = None -def get_nested(keys: Iterable[str], default_value: Any) -> Any: - """Gets a nested key. - - If any key is not found, or any intermediate key does not point to a dict - value, returns 'default_value'. - """ - if _dict is None: +def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], + default_value: Any) -> Any: + if configs is None: return default_value - curr = _dict + curr = configs for key in keys: if isinstance(curr, dict) and key in curr: curr = curr[key] @@ -95,27 +99,73 @@ def get_nested(keys: Iterable[str], default_value: Any) -> Any: return curr -def set_nested(keys: Iterable[str], value: Any) -> Dict[str, Any]: +def get_nested(keys: Tuple[str, ...], + default_value: Any, + override_configs: Optional[Dict[str, Any]] = None) -> Any: + """Gets a nested key. + + If any key is not found, or any intermediate key does not point to a dict + value, returns 'default_value'. + + When 'keys' is within OVERRIDEABLE_CONFIG_KEYS, 'override_configs' must be + provided (can be empty). Otherwise, 'override_configs' must not be provided. + + Args: + keys: A tuple of strings representing the nested keys. + default_value: The default value to return if the key is not found. + override_configs: A dict of override configs with the same schema as + the config file, but only containing the keys to override. + + Returns: + The value of the nested key, or 'default_value' if not found. + """ + assert not ( + keys in constants.OVERRIDEABLE_CONFIG_KEYS and + override_configs is None), ( + f'Override configs must be provided when keys {keys} is within ' + 'constants.OVERRIDEABLE_CONFIG_KEYS: ' + f'{constants.OVERRIDEABLE_CONFIG_KEYS}') + assert not ( + keys not in constants.OVERRIDEABLE_CONFIG_KEYS and + override_configs is not None + ), (f'Override configs must not be provided when keys {keys} is not within ' + 'constants.OVERRIDEABLE_CONFIG_KEYS: ' + f'{constants.OVERRIDEABLE_CONFIG_KEYS}') + config: Dict[str, Any] = {} + if _dict is not None: + config = copy.deepcopy(_dict) + if override_configs is None: + override_configs = {} + config = _recursive_update(config, override_configs) + return _get_nested(config, keys, default_value) + + +def _recursive_update(base_config: Dict[str, Any], + override_config: Dict[str, Any]) -> Dict[str, Any]: + """Recursively updates base configuration with override configuration""" + for key, value in override_config.items(): + if (isinstance(value, dict) and key in base_config and + isinstance(base_config[key], dict)): + _recursive_update(base_config[key], value) + else: + base_config[key] = value + return base_config + + +def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: """Returns a deep-copied config with the nested key set to value. Like get_nested(), if any key is not found, this will not raise an error. """ _check_loaded_or_die() assert _dict is not None - curr = copy.deepcopy(_dict) - to_return = curr - prev = None - for i, key in enumerate(keys): - if key not in curr: - curr[key] = {} - prev = curr - curr = curr[key] - if i == len(keys) - 1: - prev_value = prev[key] - prev[key] = value - logger.debug(f'Set the value of {keys} to {value} (previous: ' - f'{prev_value}). Returning conf: {to_return}') - return to_return + override = {} + for i, key in enumerate(reversed(keys)): + if i == 0: + override = {key: value} + else: + override = {key: override} + return _recursive_update(copy.deepcopy(_dict), override) def to_dict() -> Dict[str, Any]: diff --git a/sky/task.py b/sky/task.py index 3dd254838f0..b11f1428cd3 100644 --- a/sky/task.py +++ b/sky/task.py @@ -456,8 +456,25 @@ def from_yaml_config( task.set_outputs(outputs=outputs, estimated_size_gigabytes=estimated_size_gigabytes) + # Experimental configs. + experimnetal_configs = config.pop('experimental', None) + cluster_config_override = None + if experimnetal_configs is not None: + cluster_config_override = experimnetal_configs.pop( + 'config_overrides', None) + logger.debug('Overriding skypilot config with task-level config: ' + f'{cluster_config_override}') + assert not experimnetal_configs, ('Invalid task args: ' + f'{experimnetal_configs.keys()}') + # Parse resources field. - resources_config = config.pop('resources', None) + resources_config = config.pop('resources', {}) + if cluster_config_override is not None: + assert resources_config.get('_cluster_config_overrides') is None, ( + 'Cannot set _cluster_config_overrides in both resources and ' + 'experimental.config_overrides') + resources_config[ + '_cluster_config_overrides'] = cluster_config_override task.set_resources(sky.Resources.from_yaml_config(resources_config)) service = config.pop('service', None) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index c6c6193c611..a529d61f2f6 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -4,6 +4,9 @@ https://json-schema.org/ """ import enum +from typing import Any, Dict, List, Tuple + +from sky.skylet import constants def _check_not_both_fields_present(field1: str, field2: str): @@ -145,7 +148,8 @@ def _get_single_resources_schema(): 'type': 'null', }] }, - # The following fields are for internal use only. + # The following fields are for internal use only. Should not be + # specified in the task config. '_docker_login_config': { 'type': 'object', 'required': ['username', 'password', 'server'], @@ -168,6 +172,9 @@ def _get_single_resources_schema(): '_requires_fuse': { 'type': 'boolean', }, + '_cluster_config_overrides': { + 'type': 'object', + }, } } @@ -370,6 +377,74 @@ def get_service_schema(): } +def _filter_schema(schema: dict, keys_to_keep: List[Tuple[str, ...]]) -> dict: + """Recursively filter a schema to include only certain keys. + + Args: + schema: The original schema dictionary. + keys_to_keep: List of tuples with the path of keys to retain. + + Returns: + The filtered schema. + """ + # Convert list of tuples to a dictionary for easier access + paths_dict: Dict[str, Any] = {} + for path in keys_to_keep: + current = paths_dict + for step in path: + if step not in current: + current[step] = {} + current = current[step] + + def keep_keys(current_schema: dict, current_path_dict: dict, + new_schema: dict) -> dict: + # Base case: if we reach a leaf in the path_dict, we stop. + if (not current_path_dict or not isinstance(current_schema, dict) or + not current_schema.get('properties')): + return current_schema + + if 'properties' not in new_schema: + new_schema = { + key: current_schema[key] + for key in current_schema + # We do not support the handling of `oneOf`, `anyOf`, `allOf`, + # `required` for now. + if key not in + {'properties', 'oneOf', 'anyOf', 'allOf', 'required'} + } + new_schema['properties'] = {} + for key, sub_schema in current_schema['properties'].items(): + if key in current_path_dict: + # Recursively keep keys if further path dict exists + new_schema['properties'][key] = {} + current_path_value = current_path_dict.pop(key) + new_schema['properties'][key] = keep_keys( + sub_schema, current_path_value, + new_schema['properties'][key]) + + return new_schema + + # Start the recursive filtering + new_schema = keep_keys(schema, paths_dict, {}) + assert not paths_dict, f'Unprocessed keys: {paths_dict}' + return new_schema + + +def _experimental_task_schema() -> dict: + config_override_schema = _filter_schema(get_config_schema(), + constants.OVERRIDEABLE_CONFIG_KEYS) + return { + 'experimental': { + 'type': 'object', + 'required': [], + 'additionalProperties': False, + 'properties': { + 'config_overrides': config_override_schema, + } + } + } + + def get_task_schema(): return { '$schema': 'https://json-schema.org/draft/2020-12/schema', @@ -435,6 +510,7 @@ def get_task_schema(): 'type': 'number' } }, + **_experimental_task_schema(), } } diff --git a/tests/conftest.py b/tests/conftest.py index ce92afd88c7..b4e025a8f2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,7 +199,7 @@ def generic_cloud(request) -> str: @pytest.fixture -def enable_all_clouds(monkeypatch: pytest.MonkeyPatch): +def enable_all_clouds(monkeypatch: pytest.MonkeyPatch) -> None: common.enable_all_clouds_in_monkeypatch(monkeypatch) diff --git a/tests/test_config.py b/tests/test_config.py index 44154d7348d..c01f06d6fca 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,7 +4,9 @@ import pytest +import sky from sky import skypilot_config +from sky.skylet import constants from sky.utils import common_utils from sky.utils import kubernetes_enums @@ -12,6 +14,9 @@ PROXY_COMMAND = 'ssh -W %h:%p -i ~/.ssh/id_rsa -o StrictHostKeyChecking=no' NODEPORT_MODE_NAME = kubernetes_enums.KubernetesNetworkingMode.NODEPORT.value PORT_FORWARD_MODE_NAME = kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD.value +RUN_DURATION = 30 +RUN_DURATION_OVERRIDE = 10 +PROVISION_TIMEOUT = 600 def _reload_config() -> None: @@ -31,7 +36,7 @@ def _check_empty_config() -> None: def _create_config_file(config_file_path: pathlib.Path) -> None: - config_file_path.open('w', encoding='utf-8').write( + config_file_path.write_text( textwrap.dedent(f"""\ aws: vpc_name: {VPC_NAME} @@ -41,12 +46,56 @@ def _create_config_file(config_file_path: pathlib.Path) -> None: gcp: vpc_name: {VPC_NAME} use_internal_ips: true + managed_instance_group: + run_duration: {RUN_DURATION} + provision_timeout: {PROVISION_TIMEOUT} kubernetes: networking: {NODEPORT_MODE_NAME} + pod_config: + spec: + metadata: + annotations: + my_annotation: my_value + runtimeClassName: nvidia # Custom runtimeClassName for GPU pods. + imagePullSecrets: + - name: my-secret # Pull images from a private registry using a secret + """)) +def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: + task_file_path.write_text( + textwrap.dedent(f"""\ + experimental: + config_overrides: + docker: + run_options: + - -v /tmp:/tmp + kubernetes: + pod_config: + metadata: + labels: + test-key: test-value + annotations: + abc: def + spec: + imagePullSecrets: + - name: my-secret-2 + provision_timeout: 100 + gcp: + managed_instance_group: + run_duration: {RUN_DURATION_OVERRIDE} + nvidia_gpus: + disable_ecc: true + resources: + image_id: docker:ubuntu:latest + + setup: echo 'Setting up...' + run: echo 'Running...' + """)) + + def test_no_config(monkeypatch) -> None: """Test that the config is not loaded if the config file does not exist.""" monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', '/tmp/does_not_exist') @@ -230,3 +279,98 @@ def test_config_with_env(monkeypatch, tmp_path) -> None: None) == PROXY_COMMAND assert skypilot_config.get_nested(('gcp', 'vpc_name'), None) == VPC_NAME assert skypilot_config.get_nested(('gcp', 'use_internal_ips'), None) + + +def test_k8s_config_with_override(monkeypatch, tmp_path, + enable_all_clouds) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + task_path = tmp_path / 'task.yaml' + _create_task_yaml_file(task_path) + task = sky.Task.from_yaml(task_path) + + # Test Kubernetes overrides + # Get cluster YAML + cluster_name = 'test-kubernetes-config-with-override' + task.set_resources_override({'cloud': sky.Kubernetes()}) + sky.launch(task, cluster_name=cluster_name, dryrun=True) + cluster_yaml = pathlib.Path( + f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename( + tmp_path / (cluster_name + '.yml')) + + # Load the cluster YAML + cluster_config = common_utils.read_yaml(cluster_yaml) + head_node_type = cluster_config['head_node_type'] + cluster_pod_config = cluster_config['available_node_types'][head_node_type][ + 'node_config'] + assert cluster_pod_config['metadata']['labels']['test-key'] == 'test-value' + assert cluster_pod_config['metadata']['labels']['parent'] == 'skypilot' + assert cluster_pod_config['metadata']['annotations']['abc'] == 'def' + assert len(cluster_pod_config['spec'] + ['imagePullSecrets']) == 1 and cluster_pod_config['spec'][ + 'imagePullSecrets'][0]['name'] == 'my-secret-2' + assert cluster_pod_config['spec']['runtimeClassName'] == 'nvidia' + + +def test_gcp_config_with_override(monkeypatch, tmp_path, + enable_all_clouds) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + task_path = tmp_path / 'task.yaml' + _create_task_yaml_file(task_path) + task = sky.Task.from_yaml(task_path) + + # Test GCP overrides + cluster_name = 'test-gcp-config-with-override' + task.set_resources_override({'cloud': sky.GCP(), 'accelerators': 'L4'}) + sky.launch(task, cluster_name=cluster_name, dryrun=True) + cluster_yaml = pathlib.Path( + f'~/.sky/generated/{cluster_name}.yml.tmp').expanduser().rename( + tmp_path / (cluster_name + '.yml')) + + # Load the cluster YAML + cluster_config = common_utils.read_yaml(cluster_yaml) + assert cluster_config['provider']['vpc_name'] == VPC_NAME + assert '-v /tmp:/tmp' in cluster_config['docker'][ + 'run_options'], cluster_config + assert constants.DISABLE_GPU_ECC_COMMAND in cluster_config[ + 'setup_commands'][0] + head_node_type = cluster_config['head_node_type'] + cluster_node_config = cluster_config['available_node_types'][ + head_node_type]['node_config'] + assert cluster_node_config['managed-instance-group'][ + 'run_duration'] == RUN_DURATION_OVERRIDE + assert cluster_node_config['managed-instance-group'][ + 'provision_timeout'] == PROVISION_TIMEOUT + + +def test_config_with_invalid_override(monkeypatch, tmp_path, + enable_all_clouds) -> None: + config_path = tmp_path / 'config.yaml' + _create_config_file(config_path) + monkeypatch.setattr(skypilot_config, 'CONFIG_PATH', config_path) + + _reload_config() + + task_config_yaml = textwrap.dedent(f"""\ + experimental: + config_overrides: + gcp: + vpc_name: abc + resources: + image_id: docker:ubuntu:latest + + setup: echo 'Setting up...' + run: echo 'Running...' + """) + + with pytest.raises(ValueError, match='Found unsupported') as e: + task_path = tmp_path / 'task.yaml' + task_path.write_text(task_config_yaml) + sky.Task.from_yaml(task_path)