From 130a5aa7d011c804de26e826112cf363e4804dc5 Mon Sep 17 00:00:00 2001 From: Carmen Chow Date: Mon, 3 Jun 2024 14:05:00 -0700 Subject: [PATCH] Update spark-run to no longer override the aws-profile flag with the service flag --- service_configuration_lib/spark_config.py | 9 +++++++++ tests/spark_config_test.py | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py index b5f8bc3..72f74d4 100644 --- a/service_configuration_lib/spark_config.py +++ b/service_configuration_lib/spark_config.py @@ -148,6 +148,8 @@ def get_aws_credentials( session['Credentials']['SecretAccessKey'], session['Credentials']['SessionToken'], ) + elif profile_name: + return use_aws_profile(profile_name=profile_name, session=session) elif service != DEFAULT_SPARK_SERVICE: service_credentials_path = os.path.join(AWS_CREDENTIALS_DIR, f'{service}.yaml') if os.path.exists(service_credentials_path): @@ -158,6 +160,13 @@ def get_aws_credentials( 'Falling back to user credentials.', ) + return use_aws_profile(session=session) + + +def use_aws_profile( + profile_name: str = 'default', + session: Optional[boto3.Session] = None, +) -> Tuple[Optional[str], Optional[str], Optional[str]]: session = session or Session(profile_name=profile_name) creds = session.get_credentials() return ( diff --git a/tests/spark_config_test.py b/tests/spark_config_test.py index cfc75ef..60150ed 100644 --- a/tests/spark_config_test.py +++ b/tests/spark_config_test.py @@ -100,9 +100,21 @@ def test_use_aws_credentials_json(self, tmpdir): fp.write(json.dumps({'accessKeyId': self.access_key, 'secretAccessKey': self.secret_key})) assert spark_config.get_aws_credentials(aws_credentials_json=str(fp)) == self.expected_creds + def test_use_profile_and_service(self, mock_session): + profile = 'test_profile' + service = 'test_service' + assert spark_config.get_aws_credentials(profile_name=profile, service=service) == self.expected_temp_creds + def test_use_profile(self, mock_session): assert spark_config.get_aws_credentials(profile_name='test_profile') == self.expected_temp_creds + def test_use_default_profile(self, mock_session): + assert spark_config.get_aws_credentials(service=spark_config.DEFAULT_SPARK_SERVICE) == self.expected_temp_creds + + def test_no_service_specified(self, mock_session): + # should default to using the `default` user profile if no other credentials specified + assert spark_config.get_aws_credentials() == self.expected_temp_creds + @pytest.fixture def mock_client(self): mock_client = mock.Mock()