Skip to content

Commit

Permalink
Merge pull request #144 from Yelp/spark_run_aws_profile_behaviour
Browse files Browse the repository at this point in the history
Update spark-run to no longer override --aws-profile
  • Loading branch information
choww authored Jun 4, 2024
2 parents 8fb7b30 + 130a5aa commit 36407dc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
9 changes: 9 additions & 0 deletions service_configuration_lib/spark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def get_aws_credentials(
session['Credentials']['SecretAccessKey'],
session['Credentials']['SessionToken'],
)
elif profile_name:
return use_aws_profile(profile_name=profile_name, session=session)
elif service != DEFAULT_SPARK_SERVICE:
service_credentials_path = os.path.join(AWS_CREDENTIALS_DIR, f'{service}.yaml')
if os.path.exists(service_credentials_path):
Expand All @@ -158,6 +160,13 @@ def get_aws_credentials(
'Falling back to user credentials.',
)

return use_aws_profile(session=session)


def use_aws_profile(
profile_name: str = 'default',
session: Optional[boto3.Session] = None,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
session = session or Session(profile_name=profile_name)
creds = session.get_credentials()
return (
Expand Down
12 changes: 12 additions & 0 deletions tests/spark_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,21 @@ def test_use_aws_credentials_json(self, tmpdir):
fp.write(json.dumps({'accessKeyId': self.access_key, 'secretAccessKey': self.secret_key}))
assert spark_config.get_aws_credentials(aws_credentials_json=str(fp)) == self.expected_creds

def test_use_profile_and_service(self, mock_session):
profile = 'test_profile'
service = 'test_service'
assert spark_config.get_aws_credentials(profile_name=profile, service=service) == self.expected_temp_creds

def test_use_profile(self, mock_session):
assert spark_config.get_aws_credentials(profile_name='test_profile') == self.expected_temp_creds

def test_use_default_profile(self, mock_session):
assert spark_config.get_aws_credentials(service=spark_config.DEFAULT_SPARK_SERVICE) == self.expected_temp_creds

def test_no_service_specified(self, mock_session):
# should default to using the `default` user profile if no other credentials specified
assert spark_config.get_aws_credentials() == self.expected_temp_creds

@pytest.fixture
def mock_client(self):
mock_client = mock.Mock()
Expand Down

0 comments on commit 36407dc

Please sign in to comment.