Skip to content

Commit

Permalink
fix: bug in remote identity and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JGSweets committed Jul 4, 2024
1 parent b925cf1 commit dcb7c2a
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 18 deletions.
12 changes: 7 additions & 5 deletions sky/backends/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,11 +810,13 @@ def write_cluster_config(

assert cluster_name is not None
excluded_clouds = []
remote_identity = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'),
schemas.get_default_remote_identity(str(cloud).lower()))
if remote_identity is not None and not isinstance(remote_identity, str):
for profile in remote_identity:
remote_identity_config = skypilot_config.get_nested(
(str(cloud).lower(), 'remote_identity'), None)
remote_identity = schemas.get_default_remote_identity(str(cloud).lower())
if isinstance(remote_identity_config, str):
remote_identity = remote_identity_config
if isinstance(remote_identity_config, list):
for profile in remote_identity_config:
if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]):
remote_identity = list(profile.values())[0]
break
Expand Down
2 changes: 1 addition & 1 deletion tests/test_yamls/test_aws_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ aws:
vpc_name: fake-vpc
remote_identity:
- sky-serve-fake1-*: fake1-skypilot-role
- sky-serve-fake2-*: fake1-skypilot-role
- sky-serve-fake2-*: fake2-skypilot-role

security_group_name:
- sky-serve-fake1-*: fake-1-sg
Expand Down
117 changes: 117 additions & 0 deletions tests/unit_tests/test_backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pathlib
from typing import Dict
from unittest.mock import Mock
from unittest.mock import patch

import pytest

from sky import clouds
from sky import skypilot_config
from sky.resources import Resources
from sky.resources import resources_utils
from sky.backends import backend_utils


@patch.object(skypilot_config, 'CONFIG_PATH',
'./tests/test_yamls/test_aws_config.yaml')
@patch.object(skypilot_config, '_dict', None)
@patch.object(skypilot_config, '_loaded_config_path', None)
@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True)
@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type',
return_value={'fake-acc': 2})
@patch('sky.clouds.service_catalog.get_image_id_from_tag',
return_value='fake-image')
@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg')
@patch('sky.check.get_cloud_credential_file_mounts', return_value='~/.aws/credentials')
@patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', return_value='/tmp/fake/path')
@patch('sky.utils.common_utils.fill_template')
def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None:
skypilot_config._try_load_config()

cloud = clouds.AWS()

region = clouds.Region(name='fake-region')
zones = [clouds.Zone(name='fake-zone')]
resource = Resources(cloud=cloud, instance_type='fake-type: 3')

cluster_config_template = 'aws-ray.yml.j2'

# test default
backend_utils.write_cluster_config(
to_provision=resource,
num_nodes=2,
cluster_config_template=cluster_config_template,
cluster_name="display",
local_wheel_path=pathlib.Path('/tmp/fake'),
wheel_hash='b1bd84059bc0342f7843fcbe04ab563e',
region=region,
zones=zones,
dryrun=True,
keep_launch_fields_in_existing_config=True
)

expected_subset = {
'instance_type': 'fake-type: 3',
'custom_resources': '{"fake-acc":2}',
'region': 'fake-region',
'zones': 'fake-zone',
'image_id': 'fake-image',
'security_group': 'fake-default-sg',
'security_group_managed_by_skypilot': 'true',
'vpc_name': 'fake-vpc',
'remote_identity': 'LOCAL_CREDENTIALS', # remote identity
'sky_local_path': '/tmp/fake',
'sky_wheel_hash': 'b1bd84059bc0342f7843fcbe04ab563e',
}

mock_fill_template.assert_called_once()
assert mock_fill_template.call_args[0][0] == cluster_config_template, "config template incorrect"
assert mock_fill_template.call_args[0][1].items() >= expected_subset.items(), "config fill values incorrect"

# test using cluster matches regex, top
mock_fill_template.reset_mock()
expected_subset.update({
'security_group': 'fake-1-sg',
'security_group_managed_by_skypilot': 'false',
'remote_identity': 'fake1-skypilot-role'
})
backend_utils.write_cluster_config(
to_provision=resource,
num_nodes=2,
cluster_config_template=cluster_config_template,
cluster_name="sky-serve-fake1-1234",
local_wheel_path=pathlib.Path('/tmp/fake'),
wheel_hash='b1bd84059bc0342f7843fcbe04ab563e',
region=region,
zones=zones,
dryrun=True,
keep_launch_fields_in_existing_config=True
)

mock_fill_template.assert_called_once()
assert mock_fill_template.call_args[0][0] == cluster_config_template, "config template incorrect"
assert mock_fill_template.call_args[0][1].items() >= expected_subset.items(), "config fill values incorrect"

# test using cluster matches regex, middle
mock_fill_template.reset_mock()
expected_subset.update({
'security_group': 'fake-2-sg',
'security_group_managed_by_skypilot': 'false',
'remote_identity': 'fake2-skypilot-role'
})
backend_utils.write_cluster_config(
to_provision=resource,
num_nodes=2,
cluster_config_template=cluster_config_template,
cluster_name="sky-serve-fake2-1234",
local_wheel_path=pathlib.Path('/tmp/fake'),
wheel_hash='b1bd84059bc0342f7843fcbe04ab563e',
region=region,
zones=zones,
dryrun=True,
keep_launch_fields_in_existing_config=True
)

mock_fill_template.assert_called_once()
assert mock_fill_template.call_args[0][0] == cluster_config_template, "config template incorrect"
assert mock_fill_template.call_args[0][1].items() >= expected_subset.items(), "config fill values incorrect"
24 changes: 12 additions & 12 deletions tests/unit_tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def test_kubernetes_labels_resources():
'./tests/test_yamls/test_aws_config.yaml')
@patch.object(skypilot_config, '_dict', None)
@patch.object(skypilot_config, '_loaded_config_path', None)
@patch("sky.clouds.service_catalog.instance_type_exists", return_value=True)
@patch("sky.clouds.service_catalog.get_accelerators_from_instance_type",
return_value={"fake-acc": 2})
@patch("sky.clouds.service_catalog.get_image_id_from_tag",
return_value="fake-image")
@patch.object(clouds.aws, "DEFAULT_SECURITY_GROUP_NAME", "fake-default-sg")
@patch('sky.clouds.service_catalog.instance_type_exists', return_value=True)
@patch('sky.clouds.service_catalog.get_accelerators_from_instance_type',
return_value={'fake-acc': 2})
@patch('sky.clouds.service_catalog.get_image_id_from_tag',
return_value='fake-image')
@patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg')
def test_aws_make_deploy_variables(*mocks) -> None:
skypilot_config._try_load_config()

Expand All @@ -109,7 +109,7 @@ def test_aws_make_deploy_variables(*mocks) -> None:
name_on_cloud='cloud')
region = clouds.Region(name='fake-region')
zones = [clouds.Zone(name='fake-zone')]
resource = Resources(cloud=cloud, instance_type="fake-type: 3")
resource = Resources(cloud=cloud, instance_type='fake-type: 3')
config = resource.make_deploy_variables(cluster_name,
region,
zones,
Expand All @@ -134,18 +134,18 @@ def test_aws_make_deploy_variables(*mocks) -> None:
# test using defaults
expected_config = expected_config_base.copy()
expected_config.update({
'security_group': "fake-default-sg",
'security_group': 'fake-default-sg',
'security_group_managed_by_skypilot': 'true'
})
assert config == expected_config, ('unexpected resource '
'variables generated')

# test using culuster matches regex, top
# test using cluster matches regex, top
cluster_name = resources_utils.ClusterName(
display_name='sky-serve-fake1-1234', name_on_cloud='name-on-cloud')
expected_config = expected_config_base.copy()
expected_config.update({
'security_group': "fake-1-sg",
'security_group': 'fake-1-sg',
'security_group_managed_by_skypilot': 'false'
})
config = resource.make_deploy_variables(cluster_name,
Expand All @@ -155,12 +155,12 @@ def test_aws_make_deploy_variables(*mocks) -> None:
assert config == expected_config, ('unexpected resource '
'variables generated')

# test using culuster matches regex, middle
# test using cluster matches regex, middle
cluster_name = resources_utils.ClusterName(
display_name='sky-serve-fake2-1234', name_on_cloud='name-on-cloud')
expected_config = expected_config_base.copy()
expected_config.update({
'security_group': "fake-2-sg",
'security_group': 'fake-2-sg',
'security_group_managed_by_skypilot': 'false'
})
config = resource.make_deploy_variables(cluster_name,
Expand Down

0 comments on commit dcb7c2a

Please sign in to comment.