diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index bface9232cf..b3e8bcf2f9d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -10,6 +10,7 @@ on: branches: - master - 'releases/**' + - restapi merge_group: jobs: @@ -19,18 +20,18 @@ jobs: python-version: [3.8] test-path: - tests/unit_tests - - tests/test_api.py - - tests/test_cli.py - - tests/test_config.py - - tests/test_global_user_state.py - - tests/test_jobs.py - - tests/test_list_accelerators.py - - tests/test_optimizer_dryruns.py - - tests/test_optimizer_random_dag.py - - tests/test_storage.py - - tests/test_wheels.py - - tests/test_jobs_and_serve.py - - tests/test_yaml_parser.py + # - tests/test_api.py + # - tests/test_cli.py + # - tests/test_config.py + # - tests/test_global_user_state.py + # - tests/test_jobs.py + # - tests/test_list_accelerators.py + # - tests/test_optimizer_dryruns.py + # - tests/test_optimizer_random_dag.py + # - tests/test_storage.py + # - tests/test_wheels.py + # - tests/test_jobs_and_serve.py + # - tests/test_yaml_parser.py runs-on: ubuntu-latest steps: - name: Checkout repository diff --git a/docs/source/cloud-setup/policy.rst b/docs/source/cloud-setup/policy.rst index 663b2db7f28..a42f8b574be 100644 --- a/docs/source/cloud-setup/policy.rst +++ b/docs/source/cloud-setup/policy.rst @@ -102,10 +102,10 @@ a request should be rejected, the policy should raise an exception. The ``sky.Config`` and ``sky.RequestOptions`` classes are defined as follows: -.. literalinclude:: ../../../sky/skypilot_config.py +.. literalinclude:: ../../../sky/utils/config_utils.py :language: python :pyobject: Config - :caption: `Config Class `_ + :caption: `Config Class `_ .. literalinclude:: ../../../sky/admin_policy.py diff --git a/sky/__init__.py b/sky/__init__.py index d2977359e59..d8f0b9b89d7 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -119,9 +119,9 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.optimizer import Optimizer from sky.resources import Resources from sky.skylet.job_lib import JobStatus -from sky.skypilot_config import Config from sky.task import Task from sky.utils.common import OptimizeTarget +from sky.utils.config_utils import Config from sky.utils.status_lib import ClusterStatus # Aliases. diff --git a/sky/api/requests/executor.py b/sky/api/requests/executor.py index 0993361707b..0238b844691 100644 --- a/sky/api/requests/executor.py +++ b/sky/api/requests/executor.py @@ -13,6 +13,7 @@ from sky import global_user_state from sky import models from sky import sky_logging +from sky import skypilot_config from sky.api.requests import payloads from sky.api.requests import requests from sky.api.requests.queues import mp_queue @@ -163,6 +164,7 @@ def restore_output(original_stdout, original_stderr): # Store copies of the original stdout and stderr file descriptors original_stdout, original_stderr = redirect_output(f) try: + os.environ.update(request_body.env_vars) user = models.User( id=request_body.env_vars[constants.USER_ID_ENV_VAR], @@ -172,7 +174,9 @@ def restore_output(original_stdout, original_stderr): # Force color to be enabled. os.environ['CLICOLOR_FORCE'] = '1' common.reload() - return_value = func(**request_body.to_kwargs()) + with skypilot_config.override_skypilot_config( + request_body.override_skypilot_config): + return_value = func(**request_body.to_kwargs()) except Exception as e: # pylint: disable=broad-except with ux_utils.enable_traceback(): stacktrace = traceback.format_exc() diff --git a/sky/api/requests/payloads.py b/sky/api/requests/payloads.py index d5018eb32df..5f637d12303 100644 --- a/sky/api/requests/payloads.py +++ b/sky/api/requests/payloads.py @@ -8,6 +8,7 @@ import pydantic from sky import serve +from sky import skypilot_config from sky.api import common from sky.skylet import constants from sky.utils import common as common_lib @@ -19,17 +20,31 @@ def request_body_env_vars() -> dict: env_vars = {} for env_var in os.environ: - if env_var.startswith('SKYPILOT_'): + if env_var.startswith(constants.SKYPILOT_ENV_VAR_PREFIX): env_vars[env_var] = os.environ[env_var] env_vars[constants.USER_ID_ENV_VAR] = common_utils.get_user_hash() env_vars[constants.USER_ENV_VAR] = getpass.getuser() + # Remove the path to config file, as the config content is included in the + # request body and will be merged with the config on the server side. + env_vars.pop(skypilot_config.ENV_VAR_SKYPILOT_CONFIG, None) return env_vars +def get_override_skypilot_config_from_client() -> Dict[str, Any]: + """Returns the override configs from the client.""" + config = skypilot_config.to_dict() + # Remove the API server config, as we should not specify the SkyPilot + # server endpoint on the server side. + config.pop('api_server', None) + return config + + class RequestBody(pydantic.BaseModel): """The request body for the SkyPilot API.""" env_vars: Dict[str, str] = request_body_env_vars() entrypoint_command: str = common_utils.get_pretty_entry_point() + override_skypilot_config: Optional[Dict[ + str, Any]] = get_override_skypilot_config_from_client() def to_kwargs(self) -> Dict[str, Any]: """Convert the request body to a kwargs dictionary on API server. @@ -40,6 +55,7 @@ def to_kwargs(self) -> Dict[str, Any]: kwargs = self.model_dump() kwargs.pop('env_vars') kwargs.pop('entrypoint_command') + kwargs.pop('override_skypilot_config') return kwargs diff --git a/sky/authentication.py b/sky/authentication.py index 499d276edb0..898b7264951 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -46,6 +46,7 @@ from sky.provision.kubernetes import utils as kubernetes_utils from sky.provision.lambda_cloud import lambda_utils from sky.utils import common_utils +from sky.utils import config_utils from sky.utils import kubernetes_enums from sky.utils import subprocess_utils from sky.utils import ux_utils @@ -402,7 +403,7 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: } custom_metadata = skypilot_config.get_nested( ('kubernetes', 'custom_metadata'), {}) - kubernetes_utils.merge_dicts(custom_metadata, secret_metadata) + config_utils.merge_k8s_configs(secret_metadata, custom_metadata) secret = k8s.client.V1Secret( metadata=k8s.client.V1ObjectMeta(**secret_metadata), diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index d8fac3bf638..64054d0e362 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -25,6 +25,7 @@ from sky.provision.kubernetes import network_utils from sky.skylet import constants from sky.utils import common_utils +from sky.utils import config_utils from sky.utils import env_options from sky.utils import kubernetes_enums from sky.utils import schemas @@ -1675,50 +1676,6 @@ def get_endpoint_debug_message() -> str: debug_cmd=debug_cmd) -def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): - """Merge two dictionaries into the destination dictionary. - - Updates nested dictionaries instead of replacing them. - If a list is encountered, it will be appended to the destination list. - - An exception is when the key is 'containers', in which case the - first container in the list will be fetched and merge_dict will be - called on it with the first container in the destination list. - """ - for key, value in source.items(): - if isinstance(value, dict) and key in destination: - merge_dicts(value, destination[key]) - elif isinstance(value, list) and key in destination: - assert isinstance(destination[key], list), \ - f'Expected {key} to be a list, found {destination[key]}' - if key in ['containers', 'imagePullSecrets']: - # If the key is 'containers' or 'imagePullSecrets, we take the - # first and only container/secret in the list and merge it, as - # we only support one container per pod. - assert len(value) == 1, \ - f'Expected only one container, found {value}' - merge_dicts(value[0], destination[key][0]) - elif key in ['volumes', 'volumeMounts']: - # If the key is 'volumes' or 'volumeMounts', we search for - # item with the same name and merge it. - for new_volume in value: - new_volume_name = new_volume.get('name') - if new_volume_name is not None: - destination_volume = next( - (v for v in destination[key] - if v.get('name') == new_volume_name), None) - if destination_volume is not None: - merge_dicts(new_volume, destination_volume) - else: - destination[key].append(new_volume) - else: - destination[key].extend(value) - else: - if destination is None: - destination = {} - destination[key] = value - - def combine_pod_config_fields( cluster_yaml_path: str, cluster_config_overrides: Dict[str, Any], @@ -1771,12 +1728,12 @@ def combine_pod_config_fields( override_configs={}) override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get( 'pod_config', {})) - merge_dicts(override_pod_config, kubernetes_config) + config_utils.merge_k8s_configs(kubernetes_config, override_pod_config) # Merge the kubernetes config into the YAML for both head and worker nodes. - merge_dicts( - kubernetes_config, - yaml_obj['available_node_types']['ray_head_default']['node_config']) + config_utils.merge_k8s_configs( + yaml_obj['available_node_types']['ray_head_default']['node_config'], + kubernetes_config) # Write the updated YAML back to the file common_utils.dump_yaml(cluster_yaml_path, yaml_obj) @@ -1810,7 +1767,7 @@ def combine_metadata_fields(cluster_yaml_path: str) -> None: ] for destination in combination_destinations: - merge_dicts(custom_metadata, destination) + config_utils.merge_k8s_configs(destination, custom_metadata) # Write the updated YAML back to the file common_utils.dump_yaml(cluster_yaml_path, yaml_obj) @@ -1823,7 +1780,7 @@ def merge_custom_metadata(original_metadata: Dict[str, Any]) -> None: """ custom_metadata = skypilot_config.get_nested( ('kubernetes', 'custom_metadata'), {}) - merge_dicts(custom_metadata, original_metadata) + config_utils.merge_k8s_configs(original_metadata, custom_metadata) def check_nvidia_runtime_class(context: Optional[str] = None) -> bool: diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 99820e63388..82ec5d957c4 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -56,17 +56,20 @@ 'export PATH=' f'$(echo $PATH | sed "s|$(echo ~)/{SKY_REMOTE_PYTHON_ENV_NAME}/bin:||")') +# Prefix for SkyPilot environment variables +SKYPILOT_ENV_VAR_PREFIX = 'SKYPILOT_' + # The name for the environment variable that stores the unique ID of the # current task. This will stay the same across multiple recoveries of the # same managed task. -TASK_ID_ENV_VAR = 'SKYPILOT_TASK_ID' +TASK_ID_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}TASK_ID' # This environment variable stores a '\n'-separated list of task IDs that # are within the same managed job (DAG). This can be used by the user to # retrieve the task IDs of any tasks that are within the same managed job. # This environment variable is pre-assigned before any task starts # running within the same job, and will remain constant throughout the # lifetime of the job. -TASK_ID_LIST_ENV_VAR = 'SKYPILOT_TASK_IDS' +TASK_ID_LIST_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}TASK_IDS' # The version of skylet. MUST bump this version whenever we need the skylet to # be restarted on existing clusters updated with the new version of SkyPilot, @@ -90,9 +93,9 @@ # Docker default options DEFAULT_DOCKER_CONTAINER_NAME = 'sky_container' DEFAULT_DOCKER_PORT = 10022 -DOCKER_USERNAME_ENV_VAR = 'SKYPILOT_DOCKER_USERNAME' -DOCKER_PASSWORD_ENV_VAR = 'SKYPILOT_DOCKER_PASSWORD' -DOCKER_SERVER_ENV_VAR = 'SKYPILOT_DOCKER_SERVER' +DOCKER_USERNAME_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}DOCKER_USERNAME' +DOCKER_PASSWORD_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}DOCKER_PASSWORD' +DOCKER_SERVER_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}DOCKER_SERVER' DOCKER_LOGIN_ENV_VARS = { DOCKER_USERNAME_ENV_VAR, DOCKER_PASSWORD_ENV_VAR, @@ -229,12 +232,12 @@ # is mainly used to make sure sky commands runs on a VM launched by SkyPilot # will be recognized as the same user (e.g., jobs controller or sky serve # controller). -USER_ID_ENV_VAR = 'SKYPILOT_USER_ID' +USER_ID_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}USER_ID' # The name for the environment variable that stores SkyPilot user name. # Similar to USER_ID_ENV_VAR, this is mainly used to make sure sky commands # runs on a VM launched by SkyPilot will be recognized as the same user. -USER_ENV_VAR = 'SKYPILOT_USER' +USER_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}USER' # In most clouds, cluster names can only contain lowercase letters, numbers # and hyphens. We use this regex to validate the cluster name. @@ -269,13 +272,13 @@ # The name for the environment variable that stores the URL of the SkyPilot # API server. -SKY_API_SERVER_URL_ENV_VAR = 'SKYPILOT_API_SERVER_ENDPOINT' +SKY_API_SERVER_URL_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}API_SERVER_ENDPOINT' # SkyPilot environment variables -SKYPILOT_NUM_NODES = 'SKYPILOT_NUM_NODES' -SKYPILOT_NODE_IPS = 'SKYPILOT_NODE_IPS' -SKYPILOT_NUM_GPUS_PER_NODE = 'SKYPILOT_NUM_GPUS_PER_NODE' -SKYPILOT_NODE_RANK = 'SKYPILOT_NODE_RANK' +SKYPILOT_NUM_NODES = f'{SKYPILOT_ENV_VAR_PREFIX}NUM_NODES' +SKYPILOT_NODE_IPS = f'{SKYPILOT_ENV_VAR_PREFIX}NODE_IPS' +SKYPILOT_NUM_GPUS_PER_NODE = f'{SKYPILOT_ENV_VAR_PREFIX}NUM_GPUS_PER_NODE' +SKYPILOT_NODE_RANK = f'{SKYPILOT_ENV_VAR_PREFIX}NODE_RANK' # Placeholder for the SSH user in proxy command, replaced when the ssh_user is # known after provisioning. @@ -283,13 +286,17 @@ # The keys that can be overridden in the `~/.sky/config.yaml` file. The # overrides are specified in task YAMLs. -OVERRIDEABLE_CONFIG_KEYS: List[Tuple[str, ...]] = [ +OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [ ('docker', 'run_options'), ('nvidia_gpus', 'disable_ecc'), ('kubernetes', 'pod_config'), ('kubernetes', 'provision_timeout'), ('gcp', 'managed_instance_group'), ] +DISALLOWED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [ + ('admin_policy',), + ('api_server',), +] # Constants for Azure blob storage WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60 diff --git a/sky/skypilot_config.py b/sky/skypilot_config.py index aae62afc616..71f6b1834df 100644 --- a/sky/skypilot_config.py +++ b/sky/skypilot_config.py @@ -48,16 +48,19 @@ skypilot_config.get_nested(('a', 'nonexist'), None) # ==> None skypilot_config.get_nested(('a',), None) # ==> None """ +import contextlib import copy import os import pprint -from typing import Any, Dict, Iterable, Optional, Tuple +import tempfile +from typing import Any, Dict, Iterator, Optional, Tuple import yaml from sky import sky_logging from sky.skylet import constants from sky.utils import common_utils +from sky.utils import config_utils from sky.utils import schemas from sky.utils import ux_utils @@ -75,78 +78,16 @@ # (Used internally) An env var holding the path to the local config file. This # is only used by jobs controller tasks to ensure recoveries of the same job # use the same config file. -ENV_VAR_SKYPILOT_CONFIG = 'SKYPILOT_CONFIG' +ENV_VAR_SKYPILOT_CONFIG = f'{constants.SKYPILOT_ENV_VAR_PREFIX}CONFIG' # Path to the local config file. CONFIG_PATH = '~/.sky/config.yaml' - -class Config(Dict[str, Any]): - """SkyPilot config that supports setting/getting values with nested keys.""" - - def get_nested(self, - keys: Tuple[str, ...], - default_value: Any, - override_configs: Optional[Dict[str, Any]] = None) -> Any: - """Gets a nested key. - - If any key is not found, or any intermediate key does not point to a - dict value, returns 'default_value'. - - Args: - keys: A tuple of strings representing the nested keys. - default_value: The default value to return if the key is not found. - override_configs: A dict of override configs with the same schema as - the config file, but only containing the keys to override. - - Returns: - The value of the nested key, or 'default_value' if not found. - """ - config = copy.deepcopy(self) - if override_configs is not None: - config = _recursive_update(config, override_configs) - return _get_nested(config, keys, default_value) - - def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: - """In-place sets a nested key to value. - - Like get_nested(), if any key is not found, this will not raise an - error. - """ - override = {} - for i, key in enumerate(reversed(keys)): - if i == 0: - override = {key: value} - else: - override = {key: override} - _recursive_update(self, override) - - @classmethod - def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config': - if config is None: - return cls() - return cls(**config) - - # The loaded config. -_dict = Config() +_dict = config_utils.Config() _loaded_config_path: Optional[str] = None -def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], - default_value: Any) -> Any: - if configs is None: - return default_value - curr = configs - for key in keys: - if isinstance(curr, dict) and key in curr: - curr = curr[key] - else: - return default_value - logger.debug(f'User config: {".".join(keys)} -> {curr}') - return curr - - def get_nested(keys: Tuple[str, ...], default_value: Any, override_configs: Optional[Dict[str, Any]] = None) -> Any: @@ -167,31 +108,12 @@ def get_nested(keys: Tuple[str, ...], Returns: The value of the nested key, or 'default_value' if not found. """ - assert not ( - keys in constants.OVERRIDEABLE_CONFIG_KEYS and - override_configs is None), ( - f'Override configs must be provided when keys {keys} is within ' - 'constants.OVERRIDEABLE_CONFIG_KEYS: ' - f'{constants.OVERRIDEABLE_CONFIG_KEYS}') - assert not ( - keys not in constants.OVERRIDEABLE_CONFIG_KEYS and - override_configs is not None - ), (f'Override configs must not be provided when keys {keys} is not within ' - 'constants.OVERRIDEABLE_CONFIG_KEYS: ' - f'{constants.OVERRIDEABLE_CONFIG_KEYS}') - return _dict.get_nested(keys, default_value, override_configs) - - -def _recursive_update(base_config: Config, - override_config: Dict[str, Any]) -> Config: - """Recursively updates base configuration with override configuration""" - for key, value in override_config.items(): - if (isinstance(value, dict) and key in base_config and - isinstance(base_config[key], dict)): - _recursive_update(base_config[key], value) - else: - base_config[key] = value - return base_config + return _dict.get_nested( + keys, + default_value, + override_configs, + allowed_override_keys=constants.OVERRIDEABLE_CONFIG_KEYS_IN_TASK, + disallowed_override_keys=None) def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: @@ -204,7 +126,7 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]: return dict(**copied_dict) -def to_dict() -> Config: +def to_dict() -> config_utils.Config: """Returns a deep-copied version of the current config.""" return copy.deepcopy(_dict) @@ -228,7 +150,7 @@ def _try_load_config() -> None: logger.debug(f'Using config path: {config_path}') try: config = common_utils.read_yaml(config_path) - _dict = Config.from_dict(config) + _dict = config_utils.Config.from_dict(config) _loaded_config_path = config_path logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}') except yaml.YAMLError as e: @@ -257,3 +179,33 @@ def loaded_config_path() -> Optional[str]: def loaded() -> bool: """Returns if the user configurations are loaded.""" return bool(_dict) + + +@contextlib.contextmanager +def override_skypilot_config( + override_configs: Dict[str, Any]) -> Iterator[None]: + """Overrides the user configurations.""" + # TODO(zhwu): allow admin user to extend the disallowed keys or specify + # allowed keys. + config = _dict.get_nested( + keys=tuple(), + default_value=None, + override_configs=override_configs, + allowed_override_keys=None, + disallowed_override_keys=constants.DISALLOWED_CLIENT_OVERRIDE_KEYS) + try: + with tempfile.NamedTemporaryFile(mode='w', + prefix='skypilot_config', + delete=False) as f: + common_utils.dump_yaml(f.name, dict(config)) + f.flush() + f.close() + os.environ[ENV_VAR_SKYPILOT_CONFIG] = f.name + _try_load_config() + yield + finally: + os.environ.pop(ENV_VAR_SKYPILOT_CONFIG) + try: + os.remove(f.name) + except Exception: # pylint: disable=broad-except + pass diff --git a/sky/utils/admin_policy_utils.py b/sky/utils/admin_policy_utils.py index a83283f7dd6..7b754ed7051 100644 --- a/sky/utils/admin_policy_utils.py +++ b/sky/utils/admin_policy_utils.py @@ -14,6 +14,7 @@ from sky import skypilot_config from sky import task as task_lib from sky.utils import common_utils +from sky.utils import config_utils from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -55,7 +56,7 @@ def apply( entrypoint: Union['dag_lib.Dag', 'task_lib.Task'], use_mutated_config_in_current_request: bool = True, request_options: Optional[admin_policy.RequestOptions] = None, -) -> Tuple['dag_lib.Dag', skypilot_config.Config]: +) -> Tuple['dag_lib.Dag', config_utils.Config]: """Applies an admin policy (if registered) to a DAG or a task. It mutates a Dag by applying any registered admin policy and also diff --git a/sky/utils/common.py b/sky/utils/common.py index b33e42d5a4e..8a8b27ec14c 100644 --- a/sky/utils/common.py +++ b/sky/utils/common.py @@ -37,7 +37,6 @@ def reload(): # env vars, but since controller_utils is imported before the env vars are # set, it doesn't get updated. So we need to reload it here. # pylint: disable=import-outside-toplevel - from sky import skypilot_config from sky.utils import controller_utils global SKY_SERVE_CONTROLLER_NAME global JOB_CONTROLLER_NAME @@ -46,7 +45,6 @@ def reload(): JOB_CONTROLLER_NAME = ( f'{JOB_CONTROLLER_PREFIX}{common_utils.get_user_hash()}') importlib.reload(controller_utils) - importlib.reload(skypilot_config) # Make sure the logger takes the new environment variables. This is # necessary because the logger is initialized before the environment diff --git a/sky/utils/config_utils.py b/sky/utils/config_utils.py new file mode 100644 index 00000000000..91b9d038cba --- /dev/null +++ b/sky/utils/config_utils.py @@ -0,0 +1,194 @@ +"""Utilities for nested config.""" +import copy +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from sky import sky_logging + +logger = sky_logging.init_logger(__name__) + + +class Config(Dict[str, Any]): + """SkyPilot config that supports setting/getting values with nested keys.""" + + def get_nested( + self, + keys: Tuple[str, ...], + default_value: Any, + override_configs: Optional[Dict[str, Any]] = None, + allowed_override_keys: Optional[List[Tuple[str, ...]]] = None, + disallowed_override_keys: Optional[List[Tuple[str, + ...]]] = None) -> Any: + """Gets a nested key. + + If any key is not found, or any intermediate key does not point to a + dict value, returns 'default_value'. + + Args: + keys: A tuple of strings representing the nested keys. + default_value: The default value to return if the key is not found. + override_configs: A dict of override configs with the same schema as + the config file, but only containing the keys to override. + + Returns: + The value of the nested key, or 'default_value' if not found. + """ + config = copy.deepcopy(self) + if override_configs is not None: + config = _recursive_update(config, override_configs, + allowed_override_keys, + disallowed_override_keys) + return _get_nested(config, keys, default_value) + + def set_nested(self, keys: Tuple[str, ...], value: Any) -> None: + """In-place sets a nested key to value. + + Like get_nested(), if any key is not found, this will not raise an + error. + """ + override = {} + for i, key in enumerate(reversed(keys)): + if i == 0: + override = {key: value} + else: + override = {key: override} + _recursive_update(self, override) + + @classmethod + def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config': + if config is None: + return cls() + return cls(**config) + + +def _check_allowed_and_disallowed_override_keys( + key: str, + allowed_override_keys: Optional[List[Tuple[str, ...]]] = None, + disallowed_override_keys: Optional[List[Tuple[str, ...]]] = None +) -> Tuple[Optional[List[Tuple[str, ...]]], Optional[List[Tuple[str, ...]]]]: + allowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = [] + disallowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = [] + if allowed_override_keys is not None: + for nested_key in allowed_override_keys: + if key == nested_key[0]: + if len(nested_key) == 1: + # Allowed key is fully matched, no need to check further. + allowed_keys_with_matched_prefix = None + break + assert allowed_keys_with_matched_prefix is not None + allowed_keys_with_matched_prefix.append(nested_key[1:]) + if (allowed_keys_with_matched_prefix is not None and + not allowed_keys_with_matched_prefix): + raise ValueError(f'Key {key} is not in allowed override keys: ' + f'{allowed_override_keys}') + else: + allowed_keys_with_matched_prefix = None + + if disallowed_override_keys is not None: + for nested_key in disallowed_override_keys: + if key == nested_key[0]: + if len(nested_key) == 1: + raise ValueError( + f'Key {key} is in disallowed override keys: ' + f'{disallowed_override_keys}') + assert disallowed_keys_with_matched_prefix is not None + disallowed_keys_with_matched_prefix.append(nested_key[1:]) + else: + disallowed_keys_with_matched_prefix = None + return allowed_keys_with_matched_prefix, disallowed_keys_with_matched_prefix + + +def _recursive_update( + base_config: Config, + override_config: Dict[str, Any], + allowed_override_keys: Optional[List[Tuple[str, ...]]] = None, + disallowed_override_keys: Optional[List[Tuple[str, + ...]]] = None) -> Config: + """Recursively updates base configuration with override configuration""" + for key, value in override_config.items(): + (next_allowed_override_keys, next_disallowed_override_keys + ) = _check_allowed_and_disallowed_override_keys( + key, allowed_override_keys, disallowed_override_keys) + if key == 'kubernetes' and key in base_config: + merge_k8s_configs(base_config[key], value, + next_allowed_override_keys, + next_disallowed_override_keys) + elif (isinstance(value, dict) and key in base_config and + isinstance(base_config[key], dict)): + _recursive_update(base_config[key], value, + next_allowed_override_keys, + next_disallowed_override_keys) + else: + base_config[key] = value + return base_config + + +def _get_nested(configs: Optional[Dict[str, Any]], keys: Iterable[str], + default_value: Any) -> Any: + if configs is None: + return default_value + curr = configs + for key in keys: + if isinstance(curr, dict) and key in curr: + curr = curr[key] + else: + return default_value + logger.debug(f'User config: {".".join(keys)} -> {curr}') + return curr + + +def merge_k8s_configs( + base_config: Dict[Any, Any], + override_config: Dict[Any, Any], + allowed_override_keys: Optional[List[Tuple[str, ...]]] = None, + disallowed_override_keys: Optional[List[Tuple[str, + ...]]] = None) -> None: + """Merge two configs into the base_config. + + Updates nested dictionaries instead of replacing them. + If a list is encountered, it will be appended to the base_config list. + + An exception is when the key is 'containers', in which case the + first container in the list will be fetched and merge_dict will be + called on it with the first container in the base_config list. + """ + for key, value in override_config.items(): + (next_allowed_override_keys, next_disallowed_override_keys + ) = _check_allowed_and_disallowed_override_keys( + key, allowed_override_keys, disallowed_override_keys) + if isinstance(value, dict) and key in base_config: + merge_k8s_configs(base_config[key], value, + next_allowed_override_keys, + next_disallowed_override_keys) + elif isinstance(value, list) and key in base_config: + assert isinstance(base_config[key], list), \ + f'Expected {key} to be a list, found {base_config[key]}' + if key in ['containers', 'imagePullSecrets']: + # If the key is 'containers' or 'imagePullSecrets, we take the + # first and only container/secret in the list and merge it, as + # we only support one container per pod. + assert len(value) == 1, \ + f'Expected only one container, found {value}' + merge_k8s_configs(base_config[key][0], value[0], + next_allowed_override_keys, + next_disallowed_override_keys) + elif key in ['volumes', 'volumeMounts']: + # If the key is 'volumes' or 'volumeMounts', we search for + # item with the same name and merge it. + for new_volume in value: + new_volume_name = new_volume.get('name') + if new_volume_name is not None: + destination_volume = next( + (v for v in base_config[key] + if v.get('name') == new_volume_name), None) + if destination_volume is not None: + merge_k8s_configs(destination_volume, new_volume) + else: + base_config[key].append(new_volume) + else: + base_config[key].extend(value) + else: + if base_config is None: + # TODO(zhwu): This will not be effective as the base_config + # is not returned by this function. + base_config = {} + base_config[key] = value diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index df6852dec1e..4ec9950cc90 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -27,6 +27,7 @@ from sky.skylet import log_lib from sky.utils import common from sky.utils import common_utils +from sky.utils import config_utils from sky.utils import env_options from sky.utils import registry from sky.utils import rich_utils @@ -607,7 +608,7 @@ def get_controller_resources( def _setup_proxy_command_on_controller( controller_launched_cloud: 'clouds.Cloud', - user_config: Dict[str, Any]) -> skypilot_config.Config: + user_config: Dict[str, Any]) -> config_utils.Config: """Sets up proxy command on the controller. This function should be called on the controller (remote cluster), which @@ -641,7 +642,7 @@ def _setup_proxy_command_on_controller( # (or name). It may not be a sufficient check (as it's always # possible that peering is not set up), but it may catch some # obvious errors. - config = skypilot_config.Config.from_dict(user_config) + config = config_utils.Config.from_dict(user_config) proxy_command_key = (str(controller_launched_cloud).lower(), 'ssh_proxy_command') ssh_proxy_command = config.get_nested(proxy_command_key, None) diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 7d716134272..b3d1748612a 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -448,8 +448,8 @@ def keep_keys(current_schema: dict, current_path_dict: dict, def _experimental_task_schema() -> dict: - config_override_schema = _filter_schema(get_config_schema(), - constants.OVERRIDEABLE_CONFIG_KEYS) + config_override_schema = _filter_schema( + get_config_schema(), constants.OVERRIDEABLE_CONFIG_KEYS_IN_TASK) return { 'experimental': { 'type': 'object', diff --git a/tests/test_config.py b/tests/test_config.py index 5789214dc61..12fb8e29668 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,6 +9,7 @@ from sky import skypilot_config from sky.skylet import constants from sky.utils import common_utils +from sky.utils import config_utils from sky.utils import kubernetes_enums DISK_ENCRYPTED = True @@ -22,7 +23,7 @@ def _reload_config() -> None: - skypilot_config._dict = skypilot_config.Config() + skypilot_config._dict = config_utils.Config() skypilot_config._loaded_config_path = None skypilot_config._try_load_config() @@ -101,7 +102,7 @@ def _create_task_yaml_file(task_file_path: pathlib.Path) -> None: def test_nested_config(monkeypatch) -> None: """Test that the nested config works.""" - config = skypilot_config.Config() + config = config_utils.Config() config.set_nested(('aws', 'ssh_proxy_command'), 'value') assert config == {'aws': {'ssh_proxy_command': 'value'}} diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index c9e7ad35af2..3462778c133 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -11,6 +11,7 @@ from sky import sky_logging from sky import skypilot_config from sky.utils import admin_policy_utils +from sky.utils import config_utils logger = sky_logging.init_logger(__name__) @@ -37,7 +38,7 @@ def _load_task_and_apply_policy( task: sky.Task, config_path: str, idle_minutes_to_autostop: Optional[int] = None, -) -> Tuple[sky.Dag, skypilot_config.Config]: +) -> Tuple[sky.Dag, config_utils.Config]: os.environ['SKYPILOT_CONFIG'] = config_path importlib.reload(skypilot_config) return admin_policy_utils.apply( diff --git a/tests/unit_tests/test_config_utils.py b/tests/unit_tests/test_config_utils.py new file mode 100644 index 00000000000..b419c7d05bb --- /dev/null +++ b/tests/unit_tests/test_config_utils.py @@ -0,0 +1,356 @@ +import pytest + +from sky.utils import config_utils + + +def test_merge_k8s_configs_with_container_resources(): + """Test merging Kubernetes configs with container resource specifications.""" + base_config = { + 'containers': [{ + 'resources': { + 'limits': { + 'cpu': '1', + 'memory': '1Gi' + }, + 'requests': { + 'cpu': '0.5' + } + } + }] + } + override_config = { + 'containers': [{ + 'resources': { + 'limits': { + 'memory': '2Gi' + }, + 'requests': { + 'memory': '1Gi' + } + } + }] + } + + config_utils.merge_k8s_configs(base_config, override_config) + container = base_config['containers'][0] + assert container['resources']['limits'] == {'cpu': '1', 'memory': '2Gi'} + assert container['resources']['requests'] == {'cpu': '0.5', 'memory': '1Gi'} + + +def test_merge_k8s_configs_with_deeper_override(): + base_config = { + 'containers': [{ + 'resources': { + 'limits': { + 'cpu': '1', + 'memory': '1Gi' + }, + } + }] + } + override_config = { + 'containers': [{ + 'resources': { + 'limits': { + 'memory': '2Gi' + }, + 'requests': { + 'memory': '1Gi' + } + } + }] + } + + config_utils.merge_k8s_configs(base_config, override_config) + container = base_config['containers'][0] + assert container['resources']['limits'] == {'cpu': '1', 'memory': '2Gi'} + assert container['resources']['requests'] == {'memory': '1Gi'} + + +def test_config_nested_empty_intermediate(): + """Test setting nested config with empty intermediate dictionaries.""" + config = config_utils.Config() + + # Set deeply nested value with no existing intermediate dicts + config.set_nested(('a', 'b', 'c', 'd'), 'value') + assert config.get_nested(('a', 'b', 'c', 'd'), None) == 'value' + + # Verify intermediate dictionaries were created + assert isinstance(config['a'], dict) + assert isinstance(config['a']['b'], dict) + assert isinstance(config['a']['b']['c'], dict) + + +def test_config_get_nested_with_override(): + """Test getting nested config with overrides.""" + config = config_utils.Config({'a': {'b': {'c': 1}}}) + + # Test simple override + value = config.get_nested(('a', 'b', 'c'), + default_value=None, + override_configs={'a': { + 'b': { + 'c': 2 + } + }}) + assert value == 2 + + # Test override with allowed keys + value = config.get_nested(('a', 'b', 'c'), + default_value=None, + override_configs={'a': { + 'b': { + 'c': 3 + } + }}, + allowed_override_keys=[('a', 'b', 'c')]) + assert value == 3 + + # Test override with disallowed keys + with pytest.raises(ValueError): + config.get_nested(('a', 'b', 'c'), + default_value=None, + override_configs={'a': { + 'b': { + 'c': 4 + } + }}, + disallowed_override_keys=[('a', 'b', 'c')]) + + +def test_merge_k8s_configs_with_image_pull_secrets(): + """Test merging Kubernetes configs with imagePullSecrets.""" + base_config = {'imagePullSecrets': [{'name': 'secret1'}]} + override_config = { + 'imagePullSecrets': [{ + 'name': 'secret2', + 'namespace': 'test' + }] + } + + config_utils.merge_k8s_configs(base_config, override_config) + assert len(base_config['imagePullSecrets']) == 1 + assert base_config['imagePullSecrets'][0]['name'] == 'secret2' + assert base_config['imagePullSecrets'][0]['namespace'] == 'test' + + +def test_config_override_with_allowed_keys(): + """Test config override with allowed keys restrictions.""" + base_config = config_utils.Config({ + 'aws': { + 'vpc_name': 'default-vpc', + 'security_group': 'default-sg' + }, + 'gcp': { + 'project_id': 'default-project' + } + }) + + override_config = { + 'aws': { + 'vpc_name': 'custom-vpc' + }, + 'gcp': { + 'project_id': 'custom-project' # This should fail + } + } + + # Only allow aws.vpc_name to be overridden + allowed_keys = [('aws', 'vpc_name')] + + # We raise error whenever the override key is not in the allowed keys. + with pytest.raises(ValueError, match='not in allowed override keys:'): + base_config.get_nested(('aws', 'vpc_name'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + + # Should raise error when trying to override disallowed key + with pytest.raises(ValueError, match='not in allowed override keys:'): + base_config.get_nested(('gcp', 'project_id'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + + allowed_keys = [('aws', 'vpc_name'), ('gcp', 'project_id')] + value = base_config.get_nested(('aws', 'vpc_name'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + assert value == 'custom-vpc' + + value = base_config.get_nested(('gcp', 'project_id'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + assert value == 'custom-project' + + override_config = { + 'aws': { + 'vpc_name': 'custom-vpc', + 'security_group': 'custom-sg' + } + } + with pytest.raises(ValueError, match='not in allowed override keys:'): + base_config.get_nested(('aws', 'vpc_name'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + + allowed_keys = [('aws', 'vpc_name'), ('aws', 'security_group')] + value = base_config.get_nested(('aws', 'security_group'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + assert value == 'custom-sg' + + allowed_keys = [('aws',)] + value = base_config.get_nested(('aws', 'vpc_name'), + default_value=None, + override_configs=override_config, + allowed_override_keys=allowed_keys) + assert value == 'custom-vpc' + + +def test_k8s_config_merge_with_multiple_volumes(): + """Test merging Kubernetes configs with multiple volume configurations.""" + base_config = { + 'volumes': [{ + 'name': 'vol1', + 'hostPath': '/path1' + }, { + 'name': 'vol2', + 'hostPath': '/path2' + }], + 'volumeMounts': [{ + 'name': 'vol1', + 'mountPath': '/mnt1' + }, { + 'name': 'vol2', + 'mountPath': '/mnt2' + }] + } + + override_config = { + 'volumes': [ + { + 'name': 'vol1', + 'hostPath': '/new-path1' + }, # Should update existing + { + 'name': 'vol3', + 'hostPath': '/path3' + } # Should append + ], + 'volumeMounts': [ + { + 'name': 'vol1', + 'mountPath': '/new-mnt1' + }, # Should update existing + { + 'name': 'vol3', + 'mountPath': '/mnt3' + } # Should append + ] + } + + config_utils.merge_k8s_configs(base_config, override_config) + + # Check volumes + assert len(base_config['volumes']) == 3 + vol1 = next(v for v in base_config['volumes'] if v['name'] == 'vol1') + assert vol1['hostPath'] == '/new-path1' + vol3 = next(v for v in base_config['volumes'] if v['name'] == 'vol3') + assert vol3['hostPath'] == '/path3' + + # Check volumeMounts + assert len(base_config['volumeMounts']) == 3 + mount1 = next(m for m in base_config['volumeMounts'] if m['name'] == 'vol1') + assert mount1['mountPath'] == '/new-mnt1' + mount3 = next(m for m in base_config['volumeMounts'] if m['name'] == 'vol3') + assert mount3['mountPath'] == '/mnt3' + + +def test_nested_config_override_precedence(): + """Test that config overrides follow correct precedence rules.""" + base_config = config_utils.Config({ + 'kubernetes': { + 'pod_config': { + 'metadata': { + 'labels': { + 'env': 'prod', + 'team': 'ml' + } + }, + 'spec': { + 'containers': [{ + 'resources': { + 'limits': { + 'cpu': '1', + 'memory': '1Gi' + } + } + }] + } + } + } + }) + + override_config = { + 'kubernetes': { + 'pod_config': { + 'metadata': { + 'labels': { + 'env': 'dev', # Should override + 'project': 'skypilot' # Should add + } + }, + 'spec': { + 'containers': [{ + 'resources': { + 'limits': { + 'memory': '2Gi' # Should override + } + } + }] + } + } + } + } + + # Get nested value with override + result = base_config.get_nested(('kubernetes', 'pod_config'), + default_value=None, + override_configs=override_config) + + # Check that labels were properly merged + assert result['metadata']['labels'] == { + 'env': 'dev', + 'team': 'ml', + 'project': 'skypilot' + } + + # Check that container resources were properly merged + container = result['spec']['containers'][0] + assert container['resources']['limits'] == {'cpu': '1', 'memory': '2Gi'} + + +def test_nested_config_override_with_nonexistent_key(): + """Test that config override with nonexistent key in base config.""" + base_config = config_utils.Config({}) + override_config = { + 'kubernetes': { + 'pod_config': { + 'metadata': { + 'labels': { + 'env': 'dev', + 'project': 'skypilot' + } + } + } + } + } + result = base_config.get_nested(('kubernetes', 'pod_config'), + default_value=None, + override_configs=override_config) + assert result == override_config['kubernetes']['pod_config'] diff --git a/tests/unit_tests/test_dag_utils.py b/tests/unit_tests/test_dag_utils.py index a083757800f..83eb48ec766 100644 --- a/tests/unit_tests/test_dag_utils.py +++ b/tests/unit_tests/test_dag_utils.py @@ -6,8 +6,8 @@ import sky from sky import jobs -from sky.utils import common_utils from sky.utils import dag_utils +from sky.utils import registry def test_jobs_recovery_fill_default_values(): @@ -23,8 +23,8 @@ def test_jobs_recovery_fill_default_values(): resources = list(dag.tasks[0].resources) assert len(resources) == 1 - assert resources[0].job_recovery[ - 'strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY + assert resources[0].job_recovery['strategy'].lower( + ) == registry.JOBS_RECOVERY_STRATEGY_REGISTRY.default task_str = textwrap.dedent("""\ resources: @@ -40,8 +40,8 @@ def test_jobs_recovery_fill_default_values(): resources = list(dag.tasks[0].resources) assert len(resources) == 1 - assert resources[0].job_recovery[ - 'strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY + assert resources[0].job_recovery['strategy'].lower( + ) == registry.JOBS_RECOVERY_STRATEGY_REGISTRY.default assert resources[0].job_recovery['max_restarts_on_errors'] == 3 task_str = textwrap.dedent(f"""\ @@ -75,8 +75,8 @@ def test_jobs_recovery_fill_default_values(): resources = list(dag.tasks[0].resources) assert len(resources) == 1 - assert resources[0].job_recovery[ - 'strategy'] == jobs.DEFAULT_RECOVERY_STRATEGY + assert resources[0].job_recovery['strategy'].lower( + ) == registry.JOBS_RECOVERY_STRATEGY_REGISTRY.default task_str = textwrap.dedent("""\ resources: diff --git a/tests/unit_tests/test_recovery_strategy.py b/tests/unit_tests/test_recovery_strategy.py index da8e8142da0..6296bc01ec4 100644 --- a/tests/unit_tests/test_recovery_strategy.py +++ b/tests/unit_tests/test_recovery_strategy.py @@ -4,7 +4,7 @@ from sky.jobs import recovery_strategy -@mock.patch('sky.down') +@mock.patch('sky.core.down') @mock.patch('sky.usage.usage_lib.messages.usage.set_internal') def test_terminate_cluster_retry_on_value_error(mock_set_internal, mock_sky_down) -> None: @@ -30,7 +30,7 @@ def test_terminate_cluster_retry_on_value_error(mock_set_internal, assert mock_set_internal.call_count == 3 -@mock.patch('sky.down') +@mock.patch('sky.core.down') @mock.patch('sky.usage.usage_lib.messages.usage.set_internal') def test_terminate_cluster_handles_nonexistent_cluster(mock_set_internal, mock_sky_down) -> None: diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 65c90544f49..1e16622e55f 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -39,14 +39,16 @@ def _run_label_test(allowed_labels: Dict[str, str], cloud: clouds.Cloud = None): """Run a test for labels with the given allowed and invalid labels.""" r_allowed = Resources(cloud=cloud, labels=allowed_labels) # Should pass + r_allowed.validate() assert r_allowed.labels == allowed_labels, ('Allowed labels ' 'should be the same') # Check for each invalid label for invalid_label, value in invalid_labels.items(): l = {invalid_label: value} + r = Resources(cloud=cloud, labels=l) with pytest.raises(ValueError): - _ = Resources(cloud=cloud, labels=l) + r.validate() assert False, (f'Resources were initialized with ' f'invalid label {invalid_label}={value}')