From ddbf3b95d4f9b2d8e6f8285408127fe8cb545126 Mon Sep 17 00:00:00 2001 From: Isaac Ong Date: Tue, 12 Sep 2023 21:12:29 -0700 Subject: [PATCH] Allow empty YAML fields (#1890) * Allow empty YAML fields * Properly handle None values in config * Add unit test * Add terminal newline * Clean up test * Add disk tier * Add test to workflow * Add docstr for validation method * Update test * Clean code * Set skip_none enabled by default * Fix workflow format * Small format fix * Fix config name * Remove experimental flag --- .github/workflows/pytest.yml | 1 + sky/backends/backend_utils.py | 15 ++++- sky/backends/onprem_utils.py | 2 +- sky/data/storage.py | 8 ++- sky/resources.py | 64 +++++++++---------- tests/test_yaml_parser.py | 113 ++++++++++++++++++++++++++++++++++ 6 files changed, 163 insertions(+), 40 deletions(-) create mode 100644 tests/test_yaml_parser.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 02375cc8093..ad2b1be6234 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -27,6 +27,7 @@ jobs: - tests/test_storage.py - tests/test_wheels.py - tests/test_spot.py + - tests/test_yaml_parser.py runs-on: ubuntu-latest steps: - name: Checkout repository diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 33d7ab26daa..b3162aa0226 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -2735,12 +2735,23 @@ def stop_handler(signum, frame): raise KeyboardInterrupt(exceptions.SIGTSTP_CODE) -def validate_schema(obj, schema, err_msg_prefix=''): - """Validates an object against a JSON schema. +def validate_schema(obj, schema, err_msg_prefix='', skip_none=True): + """Validates an object against a given JSON schema. + + Args: + obj: The object to validate. + schema: The JSON schema against which to validate the object. + err_msg_prefix: The string to prepend to the error message if + validation fails. 
+ skip_none: If True, removes fields with value None from the object + before validation. This is useful for objects whose fields should never + be None, since yaml.safe_load() loads empty fields as None. Raises: ValueError: if the object does not match the schema. """ + if skip_none: + obj = {k: v for k, v in obj.items() if v is not None} err_msg = None try: validator.SchemaValidator(schema).validate(obj) diff --git a/sky/backends/onprem_utils.py b/sky/backends/onprem_utils.py index d1724fc6785..05f015d7372 100644 --- a/sky/backends/onprem_utils.py +++ b/sky/backends/onprem_utils.py @@ -497,7 +497,7 @@ def check_local_cloud_args(cloud: Optional[str] = None, yaml_config: User's task yaml loaded into a JSON dictionary. """ yaml_cloud = None - if yaml_config is not None and 'resources' in yaml_config: + if yaml_config is not None and yaml_config.get('resources') is not None: yaml_cloud = yaml_config['resources'].get('cloud') if (cluster_name is not None and check_if_local_cloud(cluster_name)): diff --git a/sky/data/storage.py b/sky/data/storage.py index a21a2a7f4df..9dd755a1bba 100644 --- a/sky/data/storage.py +++ b/sky/data/storage.py @@ -872,7 +872,9 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': source = config.pop('source', None) store = config.pop('store', None) mode_str = config.pop('mode', None) - force_delete = config.pop('_force_delete', False) + force_delete = config.pop('_force_delete', None) + if force_delete is None: + force_delete = False if isinstance(mode_str, str): # Make mode case insensitive, if specified @@ -880,7 +882,9 @@ def from_yaml_config(cls, config: Dict[str, Any]) -> 'Storage': else: # Make sure this keeps the same as the default mode in __init__ mode = StorageMode.MOUNT - persistent = config.pop('persistent', True) + persistent = config.pop('persistent', None) + if persistent is None: + persistent = True assert not config, f'Invalid storage args: {config.keys()}' diff --git a/sky/resources.py b/sky/resources.py index 
8c139f07c44..57bb16d576d 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -1077,42 +1077,36 @@ def from_yaml_config(cls, config: Optional[Dict[str, str]]) -> 'Resources': 'Invalid resources YAML: ') resources_fields = {} - if config.get('cloud') is not None: - resources_fields['cloud'] = clouds.CLOUD_REGISTRY.from_str( - config.pop('cloud')) - if config.get('instance_type') is not None: - resources_fields['instance_type'] = config.pop('instance_type') - if config.get('cpus') is not None: - resources_fields['cpus'] = str(config.pop('cpus')) - if config.get('memory') is not None: - resources_fields['memory'] = str(config.pop('memory')) - if config.get('accelerators') is not None: - resources_fields['accelerators'] = config.pop('accelerators') - if config.get('accelerator_args') is not None: + resources_fields['cloud'] = clouds.CLOUD_REGISTRY.from_str( + config.pop('cloud', None)) + resources_fields['instance_type'] = config.pop('instance_type', None) + resources_fields['cpus'] = config.pop('cpus', None) + resources_fields['memory'] = config.pop('memory', None) + resources_fields['accelerators'] = config.pop('accelerators', None) + resources_fields['accelerator_args'] = config.pop( + 'accelerator_args', None) + resources_fields['use_spot'] = config.pop('use_spot', None) + resources_fields['spot_recovery'] = config.pop('spot_recovery', None) + resources_fields['disk_size'] = config.pop('disk_size', None) + resources_fields['region'] = config.pop('region', None) + resources_fields['zone'] = config.pop('zone', None) + resources_fields['image_id'] = config.pop('image_id', None) + resources_fields['disk_tier'] = config.pop('disk_tier', None) + resources_fields['ports'] = config.pop('ports', None) + resources_fields['_docker_login_config'] = config.pop( + '_docker_login_config', None) + resources_fields['_is_image_managed'] = config.pop( + '_is_image_managed', None) + + if resources_fields['cpus'] is not None: + resources_fields['cpus'] = 
str(resources_fields['cpus']) + if resources_fields['memory'] is not None: + resources_fields['memory'] = str(resources_fields['memory']) + if resources_fields['accelerator_args'] is not None: resources_fields['accelerator_args'] = dict( - config.pop('accelerator_args')) - if config.get('use_spot') is not None: - resources_fields['use_spot'] = config.pop('use_spot') - if config.get('spot_recovery') is not None: - resources_fields['spot_recovery'] = config.pop('spot_recovery') - if config.get('disk_size') is not None: - resources_fields['disk_size'] = int(config.pop('disk_size')) - if config.get('region') is not None: - resources_fields['region'] = config.pop('region') - if config.get('zone') is not None: - resources_fields['zone'] = config.pop('zone') - if config.get('image_id') is not None: - resources_fields['image_id'] = config.pop('image_id') - if config.get('disk_tier') is not None: - resources_fields['disk_tier'] = config.pop('disk_tier') - if config.get('ports') is not None: - resources_fields['ports'] = config.pop('ports') - if config.get('_docker_login_config') is not None: - resources_fields['_docker_login_config'] = config.pop( - '_docker_login_config') - if config.get('_is_image_managed') is not None: - resources_fields['_is_image_managed'] = config.pop( - '_is_image_managed') + resources_fields['accelerator_args']) + if resources_fields['disk_size'] is not None: + resources_fields['disk_size'] = int(resources_fields['disk_size']) assert not config, f'Invalid resource args: {config.keys()}' return Resources(**resources_fields) diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py new file mode 100644 index 00000000000..3534704a041 --- /dev/null +++ b/tests/test_yaml_parser.py @@ -0,0 +1,113 @@ +import pathlib +import textwrap + +import pytest + +from sky.task import Task + + +def _create_config_file(config: str, tmp_path: pathlib.Path) -> str: + config_path = tmp_path / 'config.yaml' + config_path.open('w').write(config) + return 
config_path + + +def test_empty_fields_task(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + name: task + resources: + workdir: examples/ + file_mounts: + setup: echo "Running setup." + num_nodes: + run: + # commented out, empty run + """), tmp_path) + task = Task.from_yaml(config_path) + + assert task.name == 'task' + assert list(task.resources)[0].is_empty() + assert task.file_mounts is None + assert task.run is None + assert task.setup == 'echo "Running setup."' + assert task.num_nodes == 1 + assert task.workdir == 'examples/' + + +def test_invalid_fields_task(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + name: task + + not_a_valid_field: + """), tmp_path) + with pytest.raises(AssertionError) as e: + Task.from_yaml(config_path) + assert 'Invalid task args' in e.value.args[0] + + +def test_empty_fields_resources(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + resources: + cloud: + region: + accelerators: V100:1 + disk_size: + use_spot: + cpus: 32 + """), tmp_path) + task = Task.from_yaml(config_path) + + resources = list(task.resources)[0] + assert resources.cloud is None + assert resources.region is None + assert resources.accelerators == {'V100': 1} + assert resources.disk_size is 256 + assert resources.use_spot is False + assert resources.cpus == '32' + + +def test_invalid_fields_resources(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + resources: + cloud: aws + not_a_valid_field: + """), tmp_path) + with pytest.raises(AssertionError) as e: + Task.from_yaml(config_path) + assert 'Invalid resource args' in e.value.args[0] + + +def test_empty_fields_storage(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + file_mounts: + /mystorage: + name: sky-dataset + source: + store: + persistent: + """), tmp_path) + task = Task.from_yaml(config_path) + + storage = task.storage_mounts['/mystorage'] + assert storage.name == 'sky-dataset' + assert 
storage.source is None + assert len(storage.stores) == 0 + assert storage.persistent is True + + +def test_invalid_fields_storage(tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + file_mounts: + /datasets-storage: + name: sky-dataset + not_a_valid_field: + """), tmp_path) + with pytest.raises(AssertionError) as e: + Task.from_yaml(config_path) + assert 'Invalid storage args' in e.value.args[0]