diff --git a/sky/task.py b/sky/task.py index cf26e13717a..cebc616dc6d 100644 --- a/sky/task.py +++ b/sky/task.py @@ -393,6 +393,11 @@ def from_yaml_config( config['service'] = _fill_in_env_vars(config['service'], config.get('envs', {})) + # Fill in any Task.envs into workdir + if config.get('workdir') is not None: + config['workdir'] = _fill_in_env_vars(config['workdir'], + config.get('envs', {})) + task = Task( config.pop('name', None), run=config.pop('run', None), diff --git a/tests/test_yaml_parser.py b/tests/test_yaml_parser.py index 1453cfe1620..7d304b60633 100644 --- a/tests/test_yaml_parser.py +++ b/tests/test_yaml_parser.py @@ -146,3 +146,14 @@ def test_invalid_empty_envs(tmp_path): with pytest.raises(ValueError) as e: Task.from_yaml(config_path) assert 'Environment variable \'env_key2\' is None.' in e.value.args[0] + + +def test_replace_envs_in_workdir(tmpdir, tmp_path): + config_path = _create_config_file( + textwrap.dedent(f"""\ + envs: + env_key1: {tmpdir} + workdir: $env_key1 + """), tmp_path) + task = Task.from_yaml(config_path) + assert task.workdir == tmpdir