Skip to content

Commit

Permalink
Allow empty YAML fields (#1890)
Browse files Browse the repository at this point in the history
* Allow empty YAML fields

* Properly handle None values in config

* Add unit test

* Add terminal newline

* Clean up test

* Add disk tier

* Add test to workflow

* Add docstr for validation method

* Update test

* Clean code

* Set skip_none enabled by default

* Fix workflow format

* Small format fix

* Fix config name

* Remove experimental flag
  • Loading branch information
iojw authored Sep 13, 2023
1 parent 9e115c9 commit ddbf3b9
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 40 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- tests/test_storage.py
- tests/test_wheels.py
- tests/test_spot.py
- tests/test_yaml_parser.py
runs-on: ubuntu-latest
steps:
- name: Checkout repository
Expand Down
15 changes: 13 additions & 2 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2735,12 +2735,23 @@ def stop_handler(signum, frame):
raise KeyboardInterrupt(exceptions.SIGTSTP_CODE)


def validate_schema(obj, schema, err_msg_prefix=''):
"""Validates an object against a JSON schema.
def validate_schema(obj, schema, err_msg_prefix='', skip_none=True):
"""Validates an object against a given JSON schema.
Args:
obj: The object to validate.
schema: The JSON schema against which to validate the object.
err_msg_prefix: The string to prepend to the error message if
validation fails.
skip_none: If True, removes fields with value None from the object
before validation. This is useful for objects that will never contain
None because yaml.safe_load() loads empty fields as None.
Raises:
ValueError: if the object does not match the schema.
"""
if skip_none:
obj = {k: v for k, v in obj.items() if v is not None}
err_msg = None
try:
validator.SchemaValidator(schema).validate(obj)
Expand Down
2 changes: 1 addition & 1 deletion sky/backends/onprem_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def check_local_cloud_args(cloud: Optional[str] = None,
yaml_config: User's task yaml loaded into a JSON dictionary.
"""
yaml_cloud = None
if yaml_config is not None and 'resources' in yaml_config:
if yaml_config is not None and yaml_config.get('resources') is not None:
yaml_cloud = yaml_config['resources'].get('cloud')

if (cluster_name is not None and check_if_local_cloud(cluster_name)):
Expand Down
8 changes: 6 additions & 2 deletions sky/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,15 +872,19 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage':
source = config.pop('source', None)
store = config.pop('store', None)
mode_str = config.pop('mode', None)
force_delete = config.pop('_force_delete', False)
force_delete = config.pop('_force_delete', None)
if force_delete is None:
force_delete = False

if isinstance(mode_str, str):
# Make mode case insensitive, if specified
mode = StorageMode(mode_str.upper())
else:
# Make sure this keeps the same as the default mode in __init__
mode = StorageMode.MOUNT
persistent = config.pop('persistent', True)
persistent = config.pop('persistent', None)
if persistent is None:
persistent = True

assert not config, f'Invalid storage args: {config.keys()}'

Expand Down
64 changes: 29 additions & 35 deletions sky/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,42 +1077,36 @@ def from_yaml_config(cls, config: Optional[Dict[str, str]]) -> 'Resources':
'Invalid resources YAML: ')

resources_fields = {}
if config.get('cloud') is not None:
resources_fields['cloud'] = clouds.CLOUD_REGISTRY.from_str(
config.pop('cloud'))
if config.get('instance_type') is not None:
resources_fields['instance_type'] = config.pop('instance_type')
if config.get('cpus') is not None:
resources_fields['cpus'] = str(config.pop('cpus'))
if config.get('memory') is not None:
resources_fields['memory'] = str(config.pop('memory'))
if config.get('accelerators') is not None:
resources_fields['accelerators'] = config.pop('accelerators')
if config.get('accelerator_args') is not None:
resources_fields['cloud'] = clouds.CLOUD_REGISTRY.from_str(
config.pop('cloud', None))
resources_fields['instance_type'] = config.pop('instance_type', None)
resources_fields['cpus'] = config.pop('cpus', None)
resources_fields['memory'] = config.pop('memory', None)
resources_fields['accelerators'] = config.pop('accelerators', None)
resources_fields['accelerator_args'] = config.pop(
'accelerator_args', None)
resources_fields['use_spot'] = config.pop('use_spot', None)
resources_fields['spot_recovery'] = config.pop('spot_recovery', None)
resources_fields['disk_size'] = config.pop('disk_size', None)
resources_fields['region'] = config.pop('region', None)
resources_fields['zone'] = config.pop('zone', None)
resources_fields['image_id'] = config.pop('image_id', None)
resources_fields['disk_tier'] = config.pop('disk_tier', None)
resources_fields['ports'] = config.pop('ports', None)
resources_fields['_docker_login_config'] = config.pop(
'_docker_login_config', None)
resources_fields['_is_image_managed'] = config.pop(
'_is_image_managed', None)

if resources_fields['cpus'] is not None:
resources_fields['cpus'] = str(resources_fields['cpus'])
if resources_fields['memory'] is not None:
resources_fields['memory'] = str(resources_fields['memory'])
if resources_fields['accelerator_args'] is not None:
resources_fields['accelerator_args'] = dict(
config.pop('accelerator_args'))
if config.get('use_spot') is not None:
resources_fields['use_spot'] = config.pop('use_spot')
if config.get('spot_recovery') is not None:
resources_fields['spot_recovery'] = config.pop('spot_recovery')
if config.get('disk_size') is not None:
resources_fields['disk_size'] = int(config.pop('disk_size'))
if config.get('region') is not None:
resources_fields['region'] = config.pop('region')
if config.get('zone') is not None:
resources_fields['zone'] = config.pop('zone')
if config.get('image_id') is not None:
resources_fields['image_id'] = config.pop('image_id')
if config.get('disk_tier') is not None:
resources_fields['disk_tier'] = config.pop('disk_tier')
if config.get('ports') is not None:
resources_fields['ports'] = config.pop('ports')
if config.get('_docker_login_config') is not None:
resources_fields['_docker_login_config'] = config.pop(
'_docker_login_config')
if config.get('_is_image_managed') is not None:
resources_fields['_is_image_managed'] = config.pop(
'_is_image_managed')
resources_fields['accelerator_args'])
if resources_fields['disk_size'] is not None:
resources_fields['disk_size'] = int(resources_fields['disk_size'])

assert not config, f'Invalid resource args: {config.keys()}'
return Resources(**resources_fields)
Expand Down
113 changes: 113 additions & 0 deletions tests/test_yaml_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import pathlib
import textwrap

import pytest

from sky.task import Task


def _create_config_file(config: str, tmp_path: pathlib.Path) -> pathlib.Path:
    """Write ``config`` to ``config.yaml`` inside ``tmp_path``.

    Args:
        config: Text to write to the file verbatim.
        tmp_path: Existing directory in which the file is created.

    Returns:
        The path of the written ``config.yaml`` file.
    """
    config_path = tmp_path / 'config.yaml'
    # write_text() opens, writes and closes the file in one call; the
    # previous open('w').write() never closed the handle explicitly.
    config_path.write_text(config)
    return config_path


def test_empty_fields_task(tmp_path):
    """A task YAML with empty top-level fields loads without error.

    Fields left empty in the YAML (``resources``, ``file_mounts``,
    ``num_nodes``, ``run``) are expected to come back as empty resources,
    ``None`` file mounts, one node and a ``None`` run command.
    """
    # Plain string literal: there are no placeholders, so no f-prefix.
    config_path = _create_config_file(
        textwrap.dedent("""\
            name: task
            resources:
            workdir: examples/
            file_mounts:
            setup: echo "Running setup."
            num_nodes:
            run:
                # commented out, empty run
            """), tmp_path)
    task = Task.from_yaml(config_path)

    assert task.name == 'task'
    assert list(task.resources)[0].is_empty()
    assert task.file_mounts is None
    assert task.run is None
    assert task.setup == 'echo "Running setup."'
    assert task.num_nodes == 1
    assert task.workdir == 'examples/'


def test_invalid_fields_task(tmp_path):
    """An unknown top-level task field is rejected even when it is empty."""
    # Plain string literal: there are no placeholders, so no f-prefix.
    config_path = _create_config_file(
        textwrap.dedent("""\
            name: task
            not_a_valid_field:
            """), tmp_path)
    with pytest.raises(AssertionError) as e:
        Task.from_yaml(config_path)
    # Checked after the `with` block so the assertion actually runs.
    assert 'Invalid task args' in e.value.args[0]


def test_empty_fields_resources(tmp_path):
    """Empty fields under ``resources`` load alongside populated ones.

    The empty ``cloud``/``region`` are expected to be ``None``, the empty
    ``disk_size``/``use_spot`` are expected to be 256 and ``False``, and the
    populated ``accelerators``/``cpus`` are parsed as given.
    """
    # Plain string literal: there are no placeholders, so no f-prefix.
    config_path = _create_config_file(
        textwrap.dedent("""\
            resources:
                cloud:
                region:
                accelerators: V100:1
                disk_size:
                use_spot:
                cpus: 32
            """), tmp_path)
    task = Task.from_yaml(config_path)

    resources = list(task.resources)[0]
    assert resources.cloud is None
    assert resources.region is None
    assert resources.accelerators == {'V100': 1}
    # Compare integers by value; `is 256` only passed because CPython
    # happens to cache small ints, and emits a SyntaxWarning on 3.8+.
    assert resources.disk_size == 256
    assert resources.use_spot is False
    assert resources.cpus == '32'


def test_invalid_fields_resources(tmp_path):
    """An unknown field under ``resources`` is rejected even when empty."""
    # Plain string literal: there are no placeholders, so no f-prefix.
    config_path = _create_config_file(
        textwrap.dedent("""\
            resources:
                cloud: aws
                not_a_valid_field:
            """), tmp_path)
    with pytest.raises(AssertionError) as e:
        Task.from_yaml(config_path)
    # Checked after the `with` block so the assertion actually runs.
    assert 'Invalid resource args' in e.value.args[0]


def test_empty_fields_storage(tmp_path):
    """Empty fields in a storage mount spec load alongside the name.

    The empty ``source`` is expected to be ``None``, the empty ``store`` to
    give no stores, and the empty ``persistent`` to be ``True``.
    """
    # Plain string literal: there are no placeholders, so no f-prefix.
    config_path = _create_config_file(
        textwrap.dedent("""\
            file_mounts:
                /mystorage:
                    name: sky-dataset
                    source:
                    store:
                    persistent:
            """), tmp_path)
    task = Task.from_yaml(config_path)

    storage = task.storage_mounts['/mystorage']
    assert storage.name == 'sky-dataset'
    assert storage.source is None
    assert len(storage.stores) == 0
    assert storage.persistent is True


def test_invalid_fields_storage(tmp_path):
    """An unknown field in a storage mount spec is rejected even when empty."""
    # Plain string literal: there are no placeholders, so no f-prefix.
    config_path = _create_config_file(
        textwrap.dedent("""\
            file_mounts:
                /datasets-storage:
                    name: sky-dataset
                    not_a_valid_field:
            """), tmp_path)
    with pytest.raises(AssertionError) as e:
        Task.from_yaml(config_path)
    # Checked after the `with` block so the assertion actually runs.
    assert 'Invalid storage args' in e.value.args[0]

0 comments on commit ddbf3b9

Please sign in to comment.