Skip to content

Commit

Permalink
Assert for override configs specification
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelvll committed Jul 8, 2024
1 parent 70f0281 commit d050328
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 24 deletions.
12 changes: 2 additions & 10 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,6 @@

_DEFAULT_DISK_SIZE_GB = 256

OVERRIDEABLE_CONFIG_KEYS = [
('docker',),
('nvidia_gpus',),
('kubernetes', 'pod_config'),
('kubernetes', 'provision_timeout'),
('gcp', 'managed_instance_group'),
]


class Resources:
"""Resources: compute requirements of Tasks.
Expand Down Expand Up @@ -1033,15 +1025,15 @@ def make_deploy_variables(self, cluster_name_on_cloud: str,
if (skypilot_config.get_nested(
('nvidia_gpus', 'disable_ecc'),
False,
override_configs=self._cluster_config_overrides) and
override_configs=self.cluster_config_overrides) and
self.accelerators is not None):
initial_setup_commands = [constants.DISABLE_GPU_ECC_COMMAND]

# Docker run options
docker_run_options = skypilot_config.get_nested(
('docker', 'run_options'),
default_value=[],
override_configs=self._cluster_config_overrides)
override_configs=self.cluster_config_overrides)
if isinstance(docker_run_options, str):
docker_run_options = [docker_run_options]
if docker_run_options and isinstance(self.cloud, clouds.Kubernetes):
Expand Down
12 changes: 12 additions & 0 deletions sky/skylet/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Constants for SkyPilot."""
from typing import List, Tuple

from packaging import version

import sky
Expand Down Expand Up @@ -261,3 +263,13 @@
# Placeholder for the SSH user in proxy command, replaced when the ssh_user is
# known after provisioning.
SKY_SSH_USER_PLACEHOLDER = 'skypilot:ssh_user'

# The keys that can be overridden in the `~/.sky/config.yaml` file. The
# overrides are specified in task YAMLs.
OVERRIDEABLE_CONFIG_KEYS: List[Tuple[str, ...]] = [
('docker', 'run_options'),
('nvidia_gpus', 'disable_ecc'),
('kubernetes', 'pod_config'),
('kubernetes', 'provision_timeout'),
('gcp', 'managed_instance_group'),
]
45 changes: 35 additions & 10 deletions sky/skypilot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@
import copy
import os
import pprint
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, Optional, Tuple

import yaml

from sky import sky_logging
from sky.skylet import constants
from sky.utils import common_utils
from sky.utils import schemas
from sky.utils import ux_utils
Expand All @@ -73,7 +74,7 @@
logger = sky_logging.init_logger(__name__)

# The loaded config.
_dict = None
_dict: Optional[Dict[str, Any]] = None
_loaded_config_path = None


Expand All @@ -98,14 +99,38 @@ def get_nested(keys: Iterable[str],
If any key is not found, or any intermediate key does not point to a dict
value, returns 'default_value'.
When 'keys' is within OVERRIDEABLE_CONFIG_KEYS, 'override_configs' must be
provided (can be empty). Otherwise, 'override_configs' must not be provided.
Args:
keys: A tuple of strings representing the nested keys.
default_value: The default value to return if the key is not found.
override_configs: A dict of override configs with the same schema as
the config file, but only containing the keys to override.
Returns:
The value of the nested key, or 'default_value' if not found.
"""
# TODO (zhwu): Verify that the override_configs is provided when keys is
# within resources.OVERRIDEABLE_CONFIG_KEYS.
if _dict is None:
if override_configs is not None:
return _get_nested(override_configs, keys, default_value)
return default_value
config = _recursive_update(copy.deepcopy(_dict), override_configs or {})
assert (
keys in constants.OVERRIDEABLE_CONFIG_KEYS or
override_configs is not None
), (f'Override configs must be provided when keys {keys} is within '
'constants.OVERRIDEABLE_CONFIG_KEYS: '
f'{constants.OVERRIDEABLE_CONFIG_KEYS}'
)
assert (
keys in constants.OVERRIDEABLE_CONFIG_KEYS or override_configs is None
), (f'Override configs must not be provided when keys {keys} is not within '
'constants.OVERRIDEABLE_CONFIG_KEYS: '
f'{constants.OVERRIDEABLE_CONFIG_KEYS}'
)
config: Dict[str, Any] = {}
if _dict is not None:
config = copy.deepcopy(_dict)
if override_configs is None:
override_configs = {}
config = _recursive_update(config, override_configs)
return _get_nested(config, keys, default_value)


Expand All @@ -121,7 +146,7 @@ def _recursive_update(base_config: Dict[str, Any],
return base_config


def set_nested(keys: Iterable[str], value: Any) -> Dict[str, Any]:
def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]:
"""Returns a deep-copied config with the nested key set to value.
Like get_nested(), if any key is not found, this will not raise an error.
Expand Down
9 changes: 5 additions & 4 deletions sky/utils/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
https://json-schema.org/
"""
import enum
from typing import Any, Dict
from typing import Any, Dict, List, Tuple

from sky.skylet import constants


def _check_not_both_fields_present(field1: str, field2: str):
Expand Down Expand Up @@ -375,7 +377,7 @@ def get_service_schema():
}


def _filter_schema(schema: dict, keys_to_keep: dict) -> dict:
def _filter_schema(schema: dict, keys_to_keep: List[Tuple[str, ...]]) -> dict:
"""Recursively filter a schema to include only certain keys.
Args:
Expand Down Expand Up @@ -424,9 +426,8 @@ def keep_keys(current_schema: dict, current_path_dict: dict,


def _experimental_task_schema() -> dict:
from sky import resources # pylint: disable=import-outside-toplevel
config_override_schema = _filter_schema(get_config_schema(),
resources.OVERRIDEABLE_CONFIG_KEYS)
constants.OVERRIDEABLE_CONFIG_KEYS)
return {
'experimental': {
'type': 'object',
Expand Down

0 comments on commit d050328

Please sign in to comment.