Skip to content

Commit

Permalink
Implement config validation (#2645)
Browse files Browse the repository at this point in the history
* Implement config validation

* Add tests

* Throw error on invalid config

* Fix bug

* Log on error

* Throw exception on invalid

* Refactor kube enums

* Fix validation bug

* Add test for invalid enum in config

* Add config path to error message

* Fix schema bug and add test

* Fix clouds list

* Update config docs

* Move k8s docs

* Update schema versions

* Update k8s config
  • Loading branch information
iojw authored Oct 24, 2023
1 parent 048f45a commit dd9450d
Show file tree
Hide file tree
Showing 15 changed files with 373 additions and 132 deletions.
30 changes: 30 additions & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,33 @@ Available fields and semantics:
# Only one element is allowed in this list, as GCP disallows multiple
# specific_reservations in a single request.
- projects/my-project/reservations/my-reservation
# Advanced Kubernetes configurations (optional).
kubernetes:
# The networking mode for accessing SSH jump pod (optional).
# This must be either: 'nodeport' or 'portforward'. If not specified, defaults to 'portforward'.
#
# nodeport: Exposes the jump pod SSH service on a static port number on each Node, allowing external access to using <NodeIP>:<NodePort>. Using this mode requires opening multiple ports on nodes in the Kubernetes cluster.
# portforward: Uses `kubectl port-forward` to create a tunnel and directly access the jump pod SSH service in the Kubernetes cluster. Does not require opening ports the cluster nodes and is more secure. 'portforward' is used as default if 'networking' is not specified.
networking: portforward
# Advanced OCI configurations (optional).
oci:
# A dict mapping region names to region-specific configurations, or `default` for the default configuration.
default:
# The OCID of the profile to use for launching instances (optional).
oci_config_profile: DEFAULT
# The OCID of the compartment to use for launching instances (optional).
compartment_ocid: ocid1.compartment.oc1..aaaaaaaahr7aicqtodxmcfor6pbqn3hvsngpftozyxzqw36gj4kh3w3kkj4q
# The image tag to use for launching general instances (optional).
image_tag_general: skypilot:cpu-ubuntu-2004
# The image tag to use for launching GPU instances (optional).
image_tag_gpu: skypilot:gpu-ubuntu-2004
ap-seoul-1:
# The OCID of the subnet to use for instances (optional).
vcn_subnet: ocid1.subnet.oc1.ap-seoul-1.aaaaaaaa5c6wndifsij6yfyfehmi3tazn6mvhhiewqmajzcrlryurnl7nuja
us-ashburn-1:
vcn_subnet: ocid1.subnet.oc1.iad.aaaaaaaafbj7i3aqc4ofjaapa5edakde6g4ea2yaslcsay32cthp7qo55pxa
11 changes: 6 additions & 5 deletions sky/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from sky.adaptors import ibm
from sky.clouds.utils import lambda_utils
from sky.utils import common_utils
from sky.utils import kubernetes_enums
from sky.utils import kubernetes_utils
from sky.utils import subprocess_utils
from sky.utils import ux_utils
Expand Down Expand Up @@ -351,12 +352,12 @@ def _get_unique_key_name():
def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:
# Default ssh session is established with kubectl port-forwarding with
# ClusterIP service.
nodeport_mode = kubernetes_utils.KubernetesNetworkingMode.NODEPORT
port_forward_mode = kubernetes_utils.KubernetesNetworkingMode.PORTFORWARD
nodeport_mode = kubernetes_enums.KubernetesNetworkingMode.NODEPORT
port_forward_mode = kubernetes_enums.KubernetesNetworkingMode.PORTFORWARD
network_mode_str = skypilot_config.get_nested(('kubernetes', 'networking'),
port_forward_mode.value)
try:
network_mode = kubernetes_utils.KubernetesNetworkingMode.from_str(
network_mode = kubernetes_enums.KubernetesNetworkingMode.from_str(
network_mode_str)
except ValueError as e:
# Add message saying "Please check: ~/.sky/config.yaml" to the error
Expand Down Expand Up @@ -392,14 +393,14 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]:

ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME
if network_mode == nodeport_mode:
service_type = kubernetes_utils.KubernetesServiceType.NODEPORT
service_type = kubernetes_enums.KubernetesServiceType.NODEPORT
elif network_mode == port_forward_mode:
kubernetes_utils.check_port_forward_mode_dependencies()
# Using `kubectl port-forward` creates a direct tunnel to jump pod and
# does not require opening any ports on Kubernetes nodes. As a result,
# the service can be a simple ClusterIP service which we access with
# `kubectl port-forward`.
service_type = kubernetes_utils.KubernetesServiceType.CLUSTERIP
service_type = kubernetes_enums.KubernetesServiceType.CLUSTERIP
else:
# This should never happen because we check for this in from_str above.
raise ValueError(f'Unsupported networking mode: {network_mode_str}')
Expand Down
57 changes: 0 additions & 57 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Util constants/functions for the backends."""
from datetime import datetime
import difflib
import enum
import getpass
import json
Expand All @@ -19,7 +18,6 @@
import colorama
import filelock
import jinja2
import jsonschema
from packaging import version
import requests
from requests import adapters
Expand Down Expand Up @@ -53,7 +51,6 @@
from sky.utils import timeline
from sky.utils import tpu_utils
from sky.utils import ux_utils
from sky.utils import validator

if typing.TYPE_CHECKING:
from sky import resources
Expand Down Expand Up @@ -2758,60 +2755,6 @@ def stop_handler(signum, frame):
raise KeyboardInterrupt(exceptions.SIGTSTP_CODE)


def validate_schema(obj, schema, err_msg_prefix='', skip_none=True):
"""Validates an object against a given JSON schema.
Args:
obj: The object to validate.
schema: The JSON schema against which to validate the object.
err_msg_prefix: The string to prepend to the error message if
validation fails.
skip_none: If True, removes fields with value None from the object
before validation. This is useful for objects that will never contain
None because yaml.safe_load() loads empty fields as None.
Raises:
ValueError: if the object does not match the schema.
"""
if skip_none:
obj = {k: v for k, v in obj.items() if v is not None}
err_msg = None
try:
validator.SchemaValidator(schema).validate(obj)
except jsonschema.ValidationError as e:
if e.validator == 'additionalProperties':
if tuple(e.schema_path) == ('properties', 'envs',
'additionalProperties'):
# Hack. Here the error is Task.envs having some invalid keys. So
# we should not print "unsupported field".
#
# This will print something like:
# 'hello world' does not match any of the regexes: <regex>
err_msg = (err_msg_prefix +
'The `envs` field contains invalid keys:\n' +
e.message)
else:
err_msg = err_msg_prefix + 'The following fields are invalid:'
known_fields = set(e.schema.get('properties', {}).keys())
for field in e.instance:
if field not in known_fields:
most_similar_field = difflib.get_close_matches(
field, known_fields, 1)
if most_similar_field:
err_msg += (f'\nInstead of {field!r}, did you mean '
f'{most_similar_field[0]!r}?')
else:
err_msg += f'\nFound unsupported field {field!r}.'
else:
# Example e.json_path value: '$.resources'
err_msg = (err_msg_prefix + e.message +
f'. Check problematic field(s): {e.json_path}')

if err_msg:
with ux_utils.print_exception_no_traceback():
raise ValueError(err_msg)


def check_public_cloud_enabled():
"""Checks if any of the public clouds is enabled.
Expand Down
6 changes: 3 additions & 3 deletions sky/backends/onprem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def check_and_get_local_clusters(suppress_error: bool = False) -> List[str]:
with open(path, 'r') as f:
yaml_config = yaml.safe_load(f)
if not suppress_error:
backend_utils.validate_schema(yaml_config,
schemas.get_cluster_schema(),
'Invalid cluster YAML: ')
common_utils.validate_schema(yaml_config,
schemas.get_cluster_schema(),
'Invalid cluster YAML: ')
user_config = yaml_config['auth']
cluster_name = yaml_config['cluster']['name']
sky_local_path = SKY_USER_LOCAL_CONFIG_PATH
Expand Down
4 changes: 2 additions & 2 deletions sky/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3515,8 +3515,8 @@ def admin_deploy(clusterspec_yaml: str):
clusterspec_yaml = ' '.join(clusterspec_yaml)
assert clusterspec_yaml
is_yaml, yaml_config = _check_yaml(clusterspec_yaml)
backend_utils.validate_schema(yaml_config, schemas.get_cluster_schema(),
'Invalid cluster YAML: ')
common_utils.validate_schema(yaml_config, schemas.get_cluster_schema(),
'Invalid cluster YAML: ')
if not is_yaml:
raise ValueError('Must specify cluster config')
assert yaml_config is not None, (is_yaml, yaml_config)
Expand Down
11 changes: 6 additions & 5 deletions sky/clouds/service_catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
from sky.clouds.service_catalog import common

CloudFilter = Optional[Union[List[str], str]]
_ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci')
ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
'kubernetes')


def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs):
if clouds is None:
clouds = list(_ALL_CLOUDS)
clouds = list(ALL_CLOUDS)

# TODO(hemil): Remove this once the common service catalog
# functions are refactored from clouds/kubernetes.py to
# kubernetes_catalog.py and add kubernetes to _ALL_CLOUDS
if method_name == 'list_accelerators':
clouds.append('kubernetes')
# kubernetes_catalog.py
if method_name != 'list_accelerators':
clouds.remove('kubernetes')

single = isinstance(clouds, str)
if single:
Expand Down
6 changes: 3 additions & 3 deletions sky/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from sky.adaptors import cloudflare
from sky.adaptors import gcp
from sky.adaptors import ibm
from sky.backends import backend_utils
from sky.data import data_transfer
from sky.data import data_utils
from sky.data import mounting_utils
from sky.data import storage_utils
from sky.data.data_utils import Rclone
from sky.utils import common_utils
from sky.utils import rich_utils
from sky.utils import schemas
from sky.utils import ux_utils
Expand Down Expand Up @@ -865,8 +865,8 @@ def warn_for_git_dir(source: str):

@classmethod
def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage':
backend_utils.validate_schema(config, schemas.get_storage_schema(),
'Invalid storage YAML: ')
common_utils.validate_schema(config, schemas.get_storage_schema(),
'Invalid storage YAML: ')

name = config.pop('name', None)
source = config.pop('source', None)
Expand Down
6 changes: 3 additions & 3 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from sky import sky_logging
from sky import skypilot_config
from sky import spot
from sky.backends import backend_utils
from sky.clouds import service_catalog
from sky.provision import docker_utils
from sky.skylet import constants
from sky.utils import accelerator_registry
from sky.utils import common_utils
from sky.utils import log_utils
from sky.utils import resources_utils
from sky.utils import schemas
Expand Down Expand Up @@ -1109,8 +1109,8 @@ def from_yaml_config(cls, config: Optional[Dict[str, str]]) -> 'Resources':
if config is None:
return Resources()

backend_utils.validate_schema(config, schemas.get_resources_schema(),
'Invalid resources YAML: ')
common_utils.validate_schema(config, schemas.get_resources_schema(),
'Invalid resources YAML: ')

resources_fields = {}
resources_fields['cloud'] = clouds.CLOUD_REGISTRY.from_str(
Expand Down
7 changes: 7 additions & 0 deletions sky/skypilot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from sky import sky_logging
from sky.clouds import cloud_registry
from sky.utils import common_utils
from sky.utils import schemas

# The config path is discovered in this order:
#
Expand Down Expand Up @@ -160,6 +161,12 @@ def _try_load_config() -> None:
logger.debug(f'Config loaded:\n{pprint.pformat(_dict)}')
except yaml.YAMLError as e:
logger.error(f'Error in loading config file ({config_path}):', e)
if _dict is not None:
common_utils.validate_schema(
_dict,
schemas.get_config_schema(),
f'Invalid config YAML ({config_path}): ',
skip_none=False)

for cloud in cloud_registry.CLOUD_REGISTRY:
_syntax_check_for_ssh_proxy_command(cloud)
Expand Down
4 changes: 2 additions & 2 deletions sky/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ def from_yaml_config(
if envs is not None and isinstance(envs, dict):
config['envs'] = {str(k): str(v) for k, v in envs.items()}

backend_utils.validate_schema(config, schemas.get_task_schema(),
'Invalid task YAML: ')
common_utils.validate_schema(config, schemas.get_task_schema(),
'Invalid task YAML: ')

# Fill in any Task.envs into file_mounts (src/dst paths, storage
# name/source).
Expand Down
58 changes: 58 additions & 0 deletions sky/utils/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utils shared between all of sky"""

import difflib
import functools
import getpass
import hashlib
Expand All @@ -16,10 +17,13 @@
import uuid

import colorama
import jsonschema
import yaml

from sky import sky_logging
from sky.skylet import constants
from sky.utils import ux_utils
from sky.utils import validator

_USER_HASH_FILE = os.path.expanduser('~/.sky/user_hash')
USER_HASH_LENGTH = 8
Expand Down Expand Up @@ -489,3 +493,57 @@ def format_float(num: Union[float, int], precision: int = 1) -> str:
if isinstance(num, int):
return str(num)
return '{:.0f}'.format(num) if num.is_integer() else f'{num:.{precision}f}'


def validate_schema(obj, schema, err_msg_prefix='', skip_none=True):
"""Validates an object against a given JSON schema.
Args:
obj: The object to validate.
schema: The JSON schema against which to validate the object.
err_msg_prefix: The string to prepend to the error message if
validation fails.
skip_none: If True, removes fields with value None from the object
before validation. This is useful for objects that will never contain
None because yaml.safe_load() loads empty fields as None.
Raises:
ValueError: if the object does not match the schema.
"""
if skip_none:
obj = {k: v for k, v in obj.items() if v is not None}
err_msg = None
try:
validator.SchemaValidator(schema).validate(obj)
except jsonschema.ValidationError as e:
if e.validator == 'additionalProperties':
if tuple(e.schema_path) == ('properties', 'envs',
'additionalProperties'):
# Hack. Here the error is Task.envs having some invalid keys. So
# we should not print "unsupported field".
#
# This will print something like:
# 'hello world' does not match any of the regexes: <regex>
err_msg = (err_msg_prefix +
'The `envs` field contains invalid keys:\n' +
e.message)
else:
err_msg = err_msg_prefix + 'The following fields are invalid:'
known_fields = set(e.schema.get('properties', {}).keys())
for field in e.instance:
if field not in known_fields:
most_similar_field = difflib.get_close_matches(
field, known_fields, 1)
if most_similar_field:
err_msg += (f'\nInstead of {field!r}, did you mean '
f'{most_similar_field[0]!r}?')
else:
err_msg += f'\nFound unsupported field {field!r}.'
else:
# Example e.json_path value: '$.resources'
err_msg = (err_msg_prefix + e.message +
f'. Check problematic field(s): {e.json_path}')

if err_msg:
with ux_utils.print_exception_no_traceback():
raise ValueError(err_msg)
29 changes: 29 additions & 0 deletions sky/utils/kubernetes_enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Kubernetes enums for SkyPilot."""
import enum


class KubernetesNetworkingMode(enum.Enum):
"""Enum for the different types of networking modes for accessing
jump pods.
"""
NODEPORT = 'nodeport'
PORTFORWARD = 'portforward'

@classmethod
def from_str(cls, mode: str) -> 'KubernetesNetworkingMode':
"""Returns the enum value for the given string."""
if mode.lower() == cls.NODEPORT.value:
return cls.NODEPORT
elif mode.lower() == cls.PORTFORWARD.value:
return cls.PORTFORWARD
else:
raise ValueError(f'Unsupported kubernetes networking mode: '
f'{mode}. The mode must be either '
f'\'{cls.PORTFORWARD.value}\' or '
f'\'{cls.NODEPORT.value}\'. ')


class KubernetesServiceType(enum.Enum):
"""Enum for the different types of services."""
NODEPORT = 'NodePort'
CLUSTERIP = 'ClusterIP'
Loading

0 comments on commit dd9450d

Please sign in to comment.