diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index b45634f00ac8..e06c682bc2c0 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -810,11 +810,13 @@ def write_cluster_config( assert cluster_name is not None excluded_clouds = [] - remote_identity = skypilot_config.get_nested( - (str(cloud).lower(), 'remote_identity'), - schemas.get_default_remote_identity(str(cloud).lower())) - if remote_identity is not None and not isinstance(remote_identity, str): - for profile in remote_identity: + remote_identity_config = skypilot_config.get_nested( + (str(cloud).lower(), 'remote_identity'), None) + remote_identity = schemas.get_default_remote_identity(str(cloud).lower()) + if isinstance(remote_identity_config, str): + remote_identity = remote_identity_config + if isinstance(remote_identity_config, list): + for profile in remote_identity_config: if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]): remote_identity = list(profile.values())[0] break diff --git a/tests/test_yamls/test_aws_config.yaml b/tests/test_yamls/test_aws_config.yaml index 4cda26dc9b7e..047334703c1c 100644 --- a/tests/test_yamls/test_aws_config.yaml +++ b/tests/test_yamls/test_aws_config.yaml @@ -2,7 +2,7 @@ aws: vpc_name: fake-vpc remote_identity: - sky-serve-fake1-*: fake1-skypilot-role - - sky-serve-fake2-*: fake1-skypilot-role + - sky-serve-fake2-*: fake2-skypilot-role security_group_name: - sky-serve-fake1-*: fake-1-sg diff --git a/tests/unit_tests/test_backend_utils.py b/tests/unit_tests/test_backend_utils.py new file mode 100644 index 000000000000..ff786aac5f20 --- /dev/null +++ b/tests/unit_tests/test_backend_utils.py @@ -0,0 +1,117 @@ +import pathlib +from typing import Dict +from unittest.mock import Mock +from unittest.mock import patch + +import pytest + +from sky import clouds +from sky import skypilot_config +from sky.resources import Resources +from sky.resources import resources_utils +from sky.backends import backend_utils + + +@patch.object(skypilot_config, 'CONFIG_PATH', + './tests/test_yamls/test_aws_config.yaml') +@patch.object(skypilot_config, '_dict', None) +@patch.object(skypilot_config, '_loaded_config_path', None) +@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) +@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', + return_value={'fake-acc': 2}) +@patch('sky.clouds.service_catalog.get_image_id_from_tag', + return_value='fake-image') +@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') +@patch('sky.check.get_cloud_credential_file_mounts', return_value='~/.aws/credentials') +@patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', return_value='/tmp/fake/path') +@patch('sky.utils.common_utils.fill_template') +def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None: + skypilot_config._try_load_config() + + cloud = clouds.AWS() + + region = clouds.Region(name='fake-region') + zones = [clouds.Zone(name='fake-zone')] + resource = Resources(cloud=cloud, instance_type='fake-type: 3') + + cluster_config_template = 'aws-ray.yml.j2' + + # test default + backend_utils.write_cluster_config( + to_provision=resource, + num_nodes=2, + cluster_config_template=cluster_config_template, + cluster_name="display", + local_wheel_path=pathlib.Path('/tmp/fake'), + wheel_hash='b1bd84059bc0342f7843fcbe04ab563e', + region=region, + zones=zones, + dryrun=True, + keep_launch_fields_in_existing_config=True + ) + + expected_subset = { + 'instance_type': 'fake-type: 3', + 'custom_resources': '{"fake-acc":2}', + 'region': 'fake-region', + 'zones': 'fake-zone', + 'image_id': 'fake-image', + 'security_group': 'fake-default-sg', + 'security_group_managed_by_skypilot': 'true', + 'vpc_name': 'fake-vpc', + 'remote_identity': 'LOCAL_CREDENTIALS', # remote identity + 'sky_local_path': '/tmp/fake', + 'sky_wheel_hash': 'b1bd84059bc0342f7843fcbe04ab563e', + } + + mock_fill_template.assert_called_once() + assert mock_fill_template.call_args[0][0] == cluster_config_template, "config template incorrect" + assert mock_fill_template.call_args[0][1].items() >= expected_subset.items(), "config fill values incorrect" + + # test using cluster matches regex, top + mock_fill_template.reset_mock() + expected_subset.update({ + 'security_group': 'fake-1-sg', + 'security_group_managed_by_skypilot': 'false', + 'remote_identity': 'fake1-skypilot-role' + }) + backend_utils.write_cluster_config( + to_provision=resource, + num_nodes=2, + cluster_config_template=cluster_config_template, + cluster_name="sky-serve-fake1-1234", + local_wheel_path=pathlib.Path('/tmp/fake'), + wheel_hash='b1bd84059bc0342f7843fcbe04ab563e', + region=region, + zones=zones, + dryrun=True, + keep_launch_fields_in_existing_config=True + ) + + mock_fill_template.assert_called_once() + assert mock_fill_template.call_args[0][0] == cluster_config_template, "config template incorrect" + assert mock_fill_template.call_args[0][1].items() >= expected_subset.items(), "config fill values incorrect" + + # test using cluster matches regex, middle + mock_fill_template.reset_mock() + expected_subset.update({ + 'security_group': 'fake-2-sg', + 'security_group_managed_by_skypilot': 'false', + 'remote_identity': 'fake2-skypilot-role' + }) + backend_utils.write_cluster_config( + to_provision=resource, + num_nodes=2, + cluster_config_template=cluster_config_template, + cluster_name="sky-serve-fake2-1234", + local_wheel_path=pathlib.Path('/tmp/fake'), + wheel_hash='b1bd84059bc0342f7843fcbe04ab563e', + region=region, + zones=zones, + dryrun=True, + keep_launch_fields_in_existing_config=True + ) + + mock_fill_template.assert_called_once() + assert mock_fill_template.call_args[0][0] == cluster_config_template, "config template incorrect" + assert mock_fill_template.call_args[0][1].items() >= expected_subset.items(), "config fill values incorrect" diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 20da1a85ce47..d7643459ef0e 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -95,12 +95,12 @@ def test_kubernetes_labels_resources(): './tests/test_yamls/test_aws_config.yaml') @patch.object(skypilot_config, '_dict', None) @patch.object(skypilot_config, '_loaded_config_path', None) -@patch("sky.clouds.service_catalog.instance_type_exists", return_value=True) -@patch("sky.clouds.service_catalog.get_accelerators_from_instance_type", - return_value={"fake-acc": 2}) -@patch("sky.clouds.service_catalog.get_image_id_from_tag", - return_value="fake-image") -@patch.object(clouds.aws, "DEFAULT_SECURITY_GROUP_NAME", "fake-default-sg") +@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True) +@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type', + return_value={'fake-acc': 2}) +@patch('sky.clouds.service_catalog.get_image_id_from_tag', + return_value='fake-image') +@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg') def test_aws_make_deploy_variables(*mocks) -> None: skypilot_config._try_load_config() @@ -109,7 +109,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: name_on_cloud='cloud') region = clouds.Region(name='fake-region') zones = [clouds.Zone(name='fake-zone')] - resource = Resources(cloud=cloud, instance_type="fake-type: 3") + resource = Resources(cloud=cloud, instance_type='fake-type: 3') config = resource.make_deploy_variables(cluster_name, region, zones, @@ -134,18 +134,18 @@ def test_aws_make_deploy_variables(*mocks) -> None: # test using defaults expected_config = expected_config_base.copy() expected_config.update({ - 'security_group': "fake-default-sg", + 'security_group': 'fake-default-sg', 'security_group_managed_by_skypilot': 'true' }) assert config == expected_config, ('unexpected resource ' 'variables generated') - # test using culuster matches regex, top + # test using cluster matches regex, top cluster_name = resources_utils.ClusterName( display_name='sky-serve-fake1-1234', name_on_cloud='name-on-cloud') expected_config = expected_config_base.copy() expected_config.update({ - 'security_group': "fake-1-sg", + 'security_group': 'fake-1-sg', 'security_group_managed_by_skypilot': 'false' }) config = resource.make_deploy_variables(cluster_name, @@ -155,12 +155,12 @@ def test_aws_make_deploy_variables(*mocks) -> None: assert config == expected_config, ('unexpected resource ' 'variables generated') - # test using culuster matches regex, middle + # test using cluster matches regex, middle cluster_name = resources_utils.ClusterName( display_name='sky-serve-fake2-1234', name_on_cloud='name-on-cloud') expected_config = expected_config_base.copy() expected_config.update({ - 'security_group': "fake-2-sg", + 'security_group': 'fake-2-sg', 'security_group_managed_by_skypilot': 'false' }) config = resource.make_deploy_variables(cluster_name,