diff --git a/.gitignore b/.gitignore index e968c3ad323..20bb0ab57fb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ Release *# *.iml tags +.vs .vscode # CI Artifacts diff --git a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h index 4b90bb01ec7..5790efddb33 100644 --- a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h +++ b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h @@ -52,8 +52,47 @@ namespace Aws */ STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration = std::chrono::minutes(60)); + /** + * Use the provided profile name from the shared configuration file and a custom STS client. + * + * @param profileName The name of the profile in the shared configuration file. + * @param duration The duration, in minutes, of the role session, after which the credentials are expired. + * The value can range from 15 minutes up to the maximum session duration setting for the role. By default, + * the duration is set to 1 hour. + * Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That + * ensures the credentials do not expire between the time they're checked and the time they're returned to + * the user. + * If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only + * when they expire. + * @param stsClientFactory A factory function that creates an STSClient with specific credentials. + * Using the overload where the function returns a shared_ptr is preferred. + * + */ STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function &stsClientFactory); + /** + * Use the provided profile name from the shared configuration file and a custom STS client. + * + * @param profileName The name of the profile in the shared configuration file. + * @param duration The duration, in minutes, of the role session, after which the credentials are expired. + * The value can range from 15 minutes up to the maximum session duration setting for the role. By default, + * the duration is set to 1 hour. + * Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That + * ensures the credentials do not expire between the time they're checked and the time they're returned to + * the user. + * If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only + * when they expire. + * @param stsClientFactory A factory function that creates an STSClient with specific credentials. + * + */ + STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function(const AWSCredentials&)> &stsClientFactory); + + /** + * Compatibility constructor to assist with overload resolution when passing nullptr for the client factory. + * + */ + STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t); + /** * Fetches the credentials set from STS following the rules defined in the shared configuration file. */ @@ -67,14 +106,16 @@ namespace Aws * Returns the assumed role credentials or empty credentials on error. */ AWSCredentials GetCredentialsFromSTS(const AWSCredentials& credentials, const Aws::String& roleARN); + AWSCredentials GetCredentialsFromWebIdentity(const Config::Profile& profile); private: AWSCredentials GetCredentialsFromSTSInternal(const Aws::String& roleArn, Aws::STS::STSClient* client); + AWSCredentials GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client); Aws::String m_profileName; AWSCredentials m_credentials; const std::chrono::minutes m_duration; const std::chrono::milliseconds m_reloadFrequency; - std::function m_stsClientFactory; + std::function(const AWSCredentials&)> m_stsClientFactory; }; } } diff --git a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp index fd82b678fba..1908e522bb3 100644 --- a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp @@ -5,11 +5,13 @@ #include #include +#include #include #include #include #include +#include #include using namespace Aws; @@ -17,6 +19,12 @@ using namespace Aws::Auth; constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider"; +template +struct NoOpDeleter +{ + void operator()(T*) {} +}; + STSProfileCredentialsProvider::STSProfileCredentialsProvider() : STSProfileCredentialsProvider(GetConfigProfileName(), std::chrono::minutes(60)/*duration*/, nullptr/*stsClientFactory*/) { @@ -27,8 +35,24 @@ STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& { } +STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t) + : m_profileName(profileName), + m_duration(duration), + m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), + m_stsClientFactory(nullptr) +{ +} + STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function &stsClientFactory) : m_profileName(profileName), + m_duration(duration), + m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), + m_stsClientFactory([=](const auto& credentials) {return std::shared_ptr(stsClientFactory(credentials), NoOpDeleter()); }) +{ +} + +STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function (const AWSCredentials&)>& stsClientFactory) + : m_profileName(profileName), m_duration(duration), m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), m_stsClientFactory(stsClientFactory) @@ -66,25 +90,27 @@ enum class ProfileState Process, SourceProfile, SelfReferencing, // special case of SourceProfile. + RoleARNWebIdentity }; /* * A valid profile can be in one of the following states. Any other state is considered invalid. - +---------+-----------+-----------+--------------+ -| | | | | -| Role | Source | Process | Static | -| ARN | Profile | | Credentials | -+------------------------------------------------+ -| | | | | -| false | false | false | TRUE | -| | | | | -| false | false | TRUE | false | -| | | | | -| TRUE | TRUE | false | false | -| | | | | -| TRUE | TRUE | false | TRUE | -| | | | | -+---------+-----------+-----------+--------------+ ++---------+-----------+-----------+--------------+------------+ +| | | | | | +| Role | Source | Process | Static | Web | +| ARN | Profile | | Credentials | Identity | ++------------------------------------------------+------------+ +| | | | | | +| false | false | false | TRUE | false | +| | | | | | +| false | false | TRUE | false | false | +| | | | | | +| TRUE | TRUE | false | false | false | +| | | | | | +| TRUE | TRUE | false | TRUE | false | +| | | | | | +| TRUE | false | false | false | TRUE | ++---------+-----------+-----------+--------------+------------+ */ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLevelProfile) @@ -93,6 +119,7 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe constexpr int PROCESS_CREDENTIALS = 2; constexpr int SOURCE_PROFILE = 4; constexpr int ROLE_ARN = 8; + constexpr int WEB_IDENTITY_TOKEN_FILE = 16; int state = 0; @@ -116,6 +143,11 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe state += ROLE_ARN; } + if (!profile.GetValue("web_identity_token_file").empty()) + { + state += WEB_IDENTITY_TOKEN_FILE; + } + if (topLevelProfile) { switch(state) @@ -133,6 +165,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe } // source-profile over-rule static credentials in top-level profiles (except when self-referencing) return ProfileState::SourceProfile; + case 24: // role arn && web identity + return ProfileState::RoleARNWebIdentity; default: // All other cases are considered malformed configuration. return ProfileState::Invalid; @@ -154,6 +188,8 @@ static ProfileState CheckProfile(const Aws::Config::Profile& profile, bool topLe return ProfileState::SelfReferencing; } return ProfileState::Static; // static credentials over-rule source-profile (except when self-referencing) + case 24: // role arn && web identity + return ProfileState::RoleARNWebIdentity; default: // All other cases are considered malformed configuration. return ProfileState::Invalid; @@ -280,10 +316,14 @@ void STSProfileCredentialsProvider::Reload() while (sourceProfiles.size() > 1) { - const auto profile = sourceProfiles.back()->second; + const auto& profile = sourceProfiles.back()->second; sourceProfiles.pop_back(); AWSCredentials stsCreds; - if (profile.GetCredentialProcess().empty()) + if (CheckProfile(profile, false /*topLevelProfile*/) == ProfileState::RoleARNWebIdentity) + { + stsCreds = GetCredentialsFromWebIdentity(profile); + } + else if (profile.GetCredentialProcess().empty()) { assert(!profile.GetCredentials().IsEmpty()); stsCreds = profile.GetCredentials(); @@ -294,7 +334,7 @@ void STSProfileCredentialsProvider::Reload() } // get the role arn from the profile at the top of the stack (which hasn't been popped out yet) - const auto arn = sourceProfiles.back()->second.GetRoleArn(); + const auto& arn = sourceProfiles.back()->second.GetRoleArn(); const auto& assumedCreds = GetCredentialsFromSTS(stsCreds, arn); sourceProfiles.back()->second.SetCredentials(assumedCreds); } @@ -337,9 +377,68 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre { using namespace Aws::STS::Model; if (m_stsClientFactory) { - return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials)); + auto client = m_stsClientFactory(credentials); + return GetCredentialsFromSTSInternal(roleArn, client.get()); } Aws::STS::STSClient stsClient {credentials}; return GetCredentialsFromSTSInternal(roleArn, &stsClient); } + +AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentityInternal(const Config::Profile& profile, Aws::STS::STSClient* client) +{ + Aws::String roleSessionName = profile.GetValue("role_session_name"); + if (roleSessionName.empty()) + { + roleSessionName = Aws::Utils::UUID::PseudoRandomUUID(); + } + + Aws::String token; + { + auto& tokenPath = profile.GetValue("web_identity_token_file"); + Aws::IFStream tokenFile(tokenPath); + if (tokenFile) { + token = Aws::String( + (std::istreambuf_iterator(tokenFile)), + std::istreambuf_iterator()); + } + else { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Can't open token file: " << tokenPath); + return {}; + } + } + + using namespace Aws::STS::Model; + AssumeRoleWithWebIdentityRequest assumeRoleRequest; + assumeRoleRequest + .WithRoleArn(profile.GetRoleArn()) + .WithRoleSessionName(roleSessionName) + .WithWebIdentityToken(token) + .WithDurationSeconds(static_cast(std::chrono::seconds(m_duration).count())); + auto outcome = client->AssumeRoleWithWebIdentity(assumeRoleRequest); + if (outcome.IsSuccess()) + { + const auto& modelCredentials = outcome.GetResult().GetCredentials(); + return {modelCredentials.GetAccessKeyId(), + modelCredentials.GetSecretAccessKey(), + modelCredentials.GetSessionToken(), + modelCredentials.GetExpiration()}; + } + else + { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Failed to assume role " << profile.GetRoleArn()); + } + return {}; +} + +AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromWebIdentity(const Config::Profile& profile) +{ + using namespace Aws::STS::Model; + if (m_stsClientFactory) { + auto client = m_stsClientFactory({}); + return GetCredentialsFromWebIdentityInternal(profile, client.get()); + } + + Aws::STS::STSClient stsClient{AWSCredentials{}}; + return GetCredentialsFromWebIdentityInternal(profile, &stsClient); +} diff --git a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp index 197535a6a2e..cf234107db0 100644 --- a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp +++ b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp @@ -313,7 +313,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARN) STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -383,7 +383,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutSourceProfile) STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -409,7 +409,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithNonExistentSourceProfile STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -556,7 +556,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile Model::AssumeRoleResult mockResult; mockResult.SetCredentials(stsCredentials); - Aws::UniquePtr stsClient; + std::shared_ptr stsClient; int stsCallCounter = 0; @@ -572,9 +572,9 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str()); EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str()); } - stsClient = Aws::MakeUnique(CLASS_TAG, creds); + stsClient = Aws::MakeShared(CLASS_TAG, creds); stsClient->MockAssumeRole(mockResult); - return stsClient.get(); + return stsClient; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -614,7 +614,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials();