Skip to content

Commit

Permalink
[Config] Override skypilot config with client-side config (#37)
Browse files Browse the repository at this point in the history
* Allow skypilot config override

* fix

* add disallowed config

* format

* fix

* fixes

* naming

* add unit tests

* Add another test

* fixes

* add unittests

* Fix unit tests

* format
  • Loading branch information
Michaelvll authored Dec 4, 2024
1 parent 13b95d8 commit 4c6e9ca
Show file tree
Hide file tree
Showing 20 changed files with 685 additions and 193 deletions.
25 changes: 13 additions & 12 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ on:
branches:
- master
- 'releases/**'
- restapi
merge_group:

jobs:
Expand All @@ -19,18 +20,18 @@ jobs:
python-version: [3.8]
test-path:
- tests/unit_tests
- tests/test_api.py
- tests/test_cli.py
- tests/test_config.py
- tests/test_global_user_state.py
- tests/test_jobs.py
- tests/test_list_accelerators.py
- tests/test_optimizer_dryruns.py
- tests/test_optimizer_random_dag.py
- tests/test_storage.py
- tests/test_wheels.py
- tests/test_jobs_and_serve.py
- tests/test_yaml_parser.py
# - tests/test_api.py
# - tests/test_cli.py
# - tests/test_config.py
# - tests/test_global_user_state.py
# - tests/test_jobs.py
# - tests/test_list_accelerators.py
# - tests/test_optimizer_dryruns.py
# - tests/test_optimizer_random_dag.py
# - tests/test_storage.py
# - tests/test_wheels.py
# - tests/test_jobs_and_serve.py
# - tests/test_yaml_parser.py
runs-on: ubuntu-latest
steps:
- name: Checkout repository
Expand Down
4 changes: 2 additions & 2 deletions docs/source/cloud-setup/policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ a request should be rejected, the policy should raise an exception.

The ``sky.Config`` and ``sky.RequestOptions`` classes are defined as follows:

.. literalinclude:: ../../../sky/skypilot_config.py
.. literalinclude:: ../../../sky/utils/config_utils.py
:language: python
:pyobject: Config
:caption: `Config Class <https://github.com/skypilot-org/skypilot/blob/master/sky/skypilot_config.py>`_
:caption: `Config Class <https://github.com/skypilot-org/skypilot/blob/master/sky/utils/config_utils.py>`_


.. literalinclude:: ../../../sky/admin_policy.py
Expand Down
2 changes: 1 addition & 1 deletion sky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]):
from sky.optimizer import Optimizer
from sky.resources import Resources
from sky.skylet.job_lib import JobStatus
from sky.skypilot_config import Config
from sky.task import Task
from sky.utils.common import OptimizeTarget
from sky.utils.config_utils import Config
from sky.utils.status_lib import ClusterStatus

# Aliases.
Expand Down
6 changes: 5 additions & 1 deletion sky/api/requests/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sky import global_user_state
from sky import models
from sky import sky_logging
from sky import skypilot_config
from sky.api.requests import payloads
from sky.api.requests import requests
from sky.api.requests.queues import mp_queue
Expand Down Expand Up @@ -163,6 +164,7 @@ def restore_output(original_stdout, original_stderr):
# Store copies of the original stdout and stderr file descriptors
original_stdout, original_stderr = redirect_output(f)
try:

os.environ.update(request_body.env_vars)
user = models.User(
id=request_body.env_vars[constants.USER_ID_ENV_VAR],
Expand All @@ -172,7 +174,9 @@ def restore_output(original_stdout, original_stderr):
# Force color to be enabled.
os.environ['CLICOLOR_FORCE'] = '1'
common.reload()
return_value = func(**request_body.to_kwargs())
with skypilot_config.override_skypilot_config(
request_body.override_skypilot_config):
return_value = func(**request_body.to_kwargs())
except Exception as e: # pylint: disable=broad-except
with ux_utils.enable_traceback():
stacktrace = traceback.format_exc()
Expand Down
18 changes: 17 additions & 1 deletion sky/api/requests/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pydantic

from sky import serve
from sky import skypilot_config
from sky.api import common
from sky.skylet import constants
from sky.utils import common as common_lib
Expand All @@ -19,17 +20,31 @@
def request_body_env_vars() -> dict:
env_vars = {}
for env_var in os.environ:
if env_var.startswith('SKYPILOT_'):
if env_var.startswith(constants.SKYPILOT_ENV_VAR_PREFIX):
env_vars[env_var] = os.environ[env_var]
env_vars[constants.USER_ID_ENV_VAR] = common_utils.get_user_hash()
env_vars[constants.USER_ENV_VAR] = getpass.getuser()
# Remove the path to config file, as the config content is included in the
# request body and will be merged with the config on the server side.
env_vars.pop(skypilot_config.ENV_VAR_SKYPILOT_CONFIG, None)
return env_vars


def get_override_skypilot_config_from_client() -> Dict[str, Any]:
    """Builds the config overrides that the client sends with a request.

    The result is whatever ``skypilot_config.to_dict()`` returns, with the
    top-level ``api_server`` entry removed if it is present.
    """
    client_config = skypilot_config.to_dict()
    # The API server endpoint must not be specified on the server side, so
    # the `api_server` section is dropped before the config is sent.
    if 'api_server' in client_config:
        del client_config['api_server']
    return client_config


class RequestBody(pydantic.BaseModel):
"""The request body for the SkyPilot API."""
env_vars: Dict[str, str] = request_body_env_vars()
entrypoint_command: str = common_utils.get_pretty_entry_point()
override_skypilot_config: Optional[Dict[
str, Any]] = get_override_skypilot_config_from_client()

def to_kwargs(self) -> Dict[str, Any]:
"""Convert the request body to a kwargs dictionary on API server.
Expand All @@ -40,6 +55,7 @@ def to_kwargs(self) -> Dict[str, Any]:
kwargs = self.model_dump()
kwargs.pop('env_vars')
kwargs.pop('entrypoint_command')
kwargs.pop('override_skypilot_config')
return kwargs


Expand Down
3 changes: 2 additions & 1 deletion sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.lambda_cloud import lambda_utils
from sky.utils import common_utils
from sky.utils import config_utils
from sky.utils import kubernetes_enums
from sky.utils import subprocess_utils
from sky.utils import ux_utils
Expand Down Expand Up @@ -402,7 +403,7 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
}
custom_metadata = skypilot_config.get_nested(
('kubernetes', 'custom_metadata'), {})
kubernetes_utils.merge_dicts(custom_metadata, secret_metadata)
config_utils.merge_k8s_configs(secret_metadata, custom_metadata)

secret = k8s.client.V1Secret(
metadata=k8s.client.V1ObjectMeta(**secret_metadata),
Expand Down
57 changes: 7 additions & 50 deletions sky/provision/kubernetes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sky.provision.kubernetes import network_utils
from sky.skylet import constants
from sky.utils import common_utils
from sky.utils import config_utils
from sky.utils import env_options
from sky.utils import kubernetes_enums
from sky.utils import schemas
Expand Down Expand Up @@ -1675,50 +1676,6 @@ def get_endpoint_debug_message() -> str:
debug_cmd=debug_cmd)


def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]) -> None:
    """Recursively merge ``source`` into ``destination`` in place.

    Merge rules, applied per key of ``source``:

    * dict value, and ``destination`` already holds a dict for the key:
      the two dicts are merged recursively.
    * list value, and the key already exists in ``destination``:
        - 'containers' / 'imagePullSecrets': ``source`` must hold exactly
          one item, which is merged into the first item of the destination
          list (or added if the destination list is empty).
        - 'volumes' / 'volumeMounts': each item is merged into the
          destination item with the same 'name', or appended if no such
          item exists. Items without a 'name' are skipped.
        - any other key: the destination list is extended with the
          source items.
    * anything else (scalar value, key missing from ``destination``, or a
      dict value whose destination counterpart is not a dict, e.g. ``None``
      from an empty YAML key): the source value replaces the destination
      value.

    Args:
        source: The dictionary providing the overriding values. It is not
            modified, but its nested objects may end up referenced from
            ``destination``.
        destination: The dictionary updated in place. Must be a dict.

    Raises:
        AssertionError: If ``source`` has a list for a key whose existing
            destination value is not a list, or if a 'containers' /
            'imagePullSecrets' list in ``source`` does not have exactly one
            item.
    """
    for key, value in source.items():
        if (isinstance(value, dict) and key in destination and
                isinstance(destination[key], dict)):
            merge_dicts(value, destination[key])
        elif isinstance(value, list) and key in destination:
            assert isinstance(destination[key], list), \
                f'Expected {key} to be a list, found {destination[key]}'
            if key in ['containers', 'imagePullSecrets']:
                # If the key is 'containers' or 'imagePullSecrets', we take
                # the first and only container/secret in the list and merge
                # it, as we only support one container per pod.
                assert len(value) == 1, \
                    f'Expected only one container, found {value}'
                if destination[key]:
                    merge_dicts(value[0], destination[key][0])
                else:
                    # Nothing to merge into: adopt the single source item
                    # instead of failing with an IndexError.
                    destination[key].extend(value)
            elif key in ['volumes', 'volumeMounts']:
                # If the key is 'volumes' or 'volumeMounts', we search for
                # the item with the same name and merge it.
                for new_volume in value:
                    new_volume_name = new_volume.get('name')
                    if new_volume_name is not None:
                        destination_volume = next(
                            (v for v in destination[key]
                             if v.get('name') == new_volume_name), None)
                        if destination_volume is not None:
                            merge_dicts(new_volume, destination_volume)
                        else:
                            destination[key].append(new_volume)
            else:
                destination[key].extend(value)
        else:
            # Covers scalars, new keys, and dict values whose destination
            # counterpart cannot be merged into (e.g. None). Assigning here,
            # rather than recursing into a non-dict, keeps the override from
            # being silently dropped.
            destination[key] = value


def combine_pod_config_fields(
cluster_yaml_path: str,
cluster_config_overrides: Dict[str, Any],
Expand Down Expand Up @@ -1771,12 +1728,12 @@ def combine_pod_config_fields(
override_configs={})
override_pod_config = (cluster_config_overrides.get('kubernetes', {}).get(
'pod_config', {}))
merge_dicts(override_pod_config, kubernetes_config)
config_utils.merge_k8s_configs(kubernetes_config, override_pod_config)

# Merge the kubernetes config into the YAML for both head and worker nodes.
merge_dicts(
kubernetes_config,
yaml_obj['available_node_types']['ray_head_default']['node_config'])
config_utils.merge_k8s_configs(
yaml_obj['available_node_types']['ray_head_default']['node_config'],
kubernetes_config)

# Write the updated YAML back to the file
common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
Expand Down Expand Up @@ -1810,7 +1767,7 @@ def combine_metadata_fields(cluster_yaml_path: str) -> None:
]

for destination in combination_destinations:
merge_dicts(custom_metadata, destination)
config_utils.merge_k8s_configs(destination, custom_metadata)

# Write the updated YAML back to the file
common_utils.dump_yaml(cluster_yaml_path, yaml_obj)
Expand All @@ -1823,7 +1780,7 @@ def merge_custom_metadata(original_metadata: Dict[str, Any]) -> None:
"""
custom_metadata = skypilot_config.get_nested(
('kubernetes', 'custom_metadata'), {})
merge_dicts(custom_metadata, original_metadata)
config_utils.merge_k8s_configs(original_metadata, custom_metadata)


def check_nvidia_runtime_class(context: Optional[str] = None) -> bool:
Expand Down
33 changes: 20 additions & 13 deletions sky/skylet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,20 @@
'export PATH='
f'$(echo $PATH | sed "s|$(echo ~)/{SKY_REMOTE_PYTHON_ENV_NAME}/bin:||")')

# Prefix for SkyPilot environment variables
SKYPILOT_ENV_VAR_PREFIX = 'SKYPILOT_'

# The name for the environment variable that stores the unique ID of the
# current task. This will stay the same across multiple recoveries of the
# same managed task.
TASK_ID_ENV_VAR = 'SKYPILOT_TASK_ID'
TASK_ID_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}TASK_ID'
# This environment variable stores a '\n'-separated list of task IDs that
# are within the same managed job (DAG). This can be used by the user to
# retrieve the task IDs of any tasks that are within the same managed job.
# This environment variable is pre-assigned before any task starts
# running within the same job, and will remain constant throughout the
# lifetime of the job.
TASK_ID_LIST_ENV_VAR = 'SKYPILOT_TASK_IDS'
TASK_ID_LIST_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}TASK_IDS'

# The version of skylet. MUST bump this version whenever we need the skylet to
# be restarted on existing clusters updated with the new version of SkyPilot,
Expand All @@ -90,9 +93,9 @@
# Docker default options
DEFAULT_DOCKER_CONTAINER_NAME = 'sky_container'
DEFAULT_DOCKER_PORT = 10022
DOCKER_USERNAME_ENV_VAR = 'SKYPILOT_DOCKER_USERNAME'
DOCKER_PASSWORD_ENV_VAR = 'SKYPILOT_DOCKER_PASSWORD'
DOCKER_SERVER_ENV_VAR = 'SKYPILOT_DOCKER_SERVER'
DOCKER_USERNAME_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}DOCKER_USERNAME'
DOCKER_PASSWORD_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}DOCKER_PASSWORD'
DOCKER_SERVER_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}DOCKER_SERVER'
DOCKER_LOGIN_ENV_VARS = {
DOCKER_USERNAME_ENV_VAR,
DOCKER_PASSWORD_ENV_VAR,
Expand Down Expand Up @@ -229,12 +232,12 @@
# is mainly used to make sure sky commands run on a VM launched by SkyPilot
# are recognized as the same user (e.g., jobs controller or sky serve
# controller).
USER_ID_ENV_VAR = 'SKYPILOT_USER_ID'
USER_ID_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}USER_ID'

# The name for the environment variable that stores the SkyPilot user name.
# Similar to USER_ID_ENV_VAR, this is mainly used to make sure sky commands
# run on a VM launched by SkyPilot are recognized as the same user.
USER_ENV_VAR = 'SKYPILOT_USER'
USER_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}USER'

# In most clouds, cluster names can only contain lowercase letters, numbers
# and hyphens. We use this regex to validate the cluster name.
Expand Down Expand Up @@ -269,27 +272,31 @@

# The name for the environment variable that stores the URL of the SkyPilot
# API server.
SKY_API_SERVER_URL_ENV_VAR = 'SKYPILOT_API_SERVER_ENDPOINT'
SKY_API_SERVER_URL_ENV_VAR = f'{SKYPILOT_ENV_VAR_PREFIX}API_SERVER_ENDPOINT'

# SkyPilot environment variables
SKYPILOT_NUM_NODES = 'SKYPILOT_NUM_NODES'
SKYPILOT_NODE_IPS = 'SKYPILOT_NODE_IPS'
SKYPILOT_NUM_GPUS_PER_NODE = 'SKYPILOT_NUM_GPUS_PER_NODE'
SKYPILOT_NODE_RANK = 'SKYPILOT_NODE_RANK'
SKYPILOT_NUM_NODES = f'{SKYPILOT_ENV_VAR_PREFIX}NUM_NODES'
SKYPILOT_NODE_IPS = f'{SKYPILOT_ENV_VAR_PREFIX}NODE_IPS'
SKYPILOT_NUM_GPUS_PER_NODE = f'{SKYPILOT_ENV_VAR_PREFIX}NUM_GPUS_PER_NODE'
SKYPILOT_NODE_RANK = f'{SKYPILOT_ENV_VAR_PREFIX}NODE_RANK'

# Placeholder for the SSH user in proxy command, replaced when the ssh_user is
# known after provisioning.
SKY_SSH_USER_PLACEHOLDER = 'skypilot:ssh_user'

# The keys that can be overridden in the `~/.sky/config.yaml` file. The
# overrides are specified in task YAMLs.
OVERRIDEABLE_CONFIG_KEYS: List[Tuple[str, ...]] = [
OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
('docker', 'run_options'),
('nvidia_gpus', 'disable_ecc'),
('kubernetes', 'pod_config'),
('kubernetes', 'provision_timeout'),
('gcp', 'managed_instance_group'),
]
DISALLOWED_CLIENT_OVERRIDE_KEYS: List[Tuple[str, ...]] = [
('admin_policy',),
('api_server',),
]

# Constants for Azure blob storage
WAIT_FOR_STORAGE_ACCOUNT_CREATION = 60
Expand Down
Loading

0 comments on commit 4c6e9ca

Please sign in to comment.