From c221644514a908ccc120316e09d17937189be097 Mon Sep 17 00:00:00 2001 From: Kevin Stapleton <3208714+kevinstapleton@users.noreply.github.com> Date: Tue, 15 Oct 2024 10:43:23 -0500 Subject: [PATCH] Perform background refresh of credentials during preempt expiry time period --- .../Credentials/RefreshingAWSCredentials.cs | 126 +++++++--- ...AWSSDK.UnitTests.Custom.NetStandard.csproj | 2 +- .../RefreshingAWSCredentialsTests.cs | 232 ++++++++++++++++++ 3 files changed, 329 insertions(+), 31 deletions(-) create mode 100644 sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs diff --git a/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs b/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs index fd5fa8ffddf4..5b5444830ff7 100644 --- a/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs +++ b/sdk/src/Core/Amazon.Runtime/Credentials/RefreshingAWSCredentials.cs @@ -36,11 +36,7 @@ public abstract class RefreshingAWSCredentials : AWSCredentials, IDisposable /// public class CredentialsRefreshState { - public ImmutableCredentials Credentials - { - get; - set; - } + public ImmutableCredentials Credentials { get; set; } public DateTime Expiration { get; set; } public CredentialsRefreshState() @@ -61,6 +57,16 @@ internal bool IsExpiredWithin(TimeSpan preemptExpiryTime) var exp = Expiration.ToUniversalTime(); return (now > exp - preemptExpiryTime); } + + internal TimeSpan GetTimeToLive(TimeSpan preemptExpiryTime) + { +#pragma warning disable CS0612,CS0618 // Type or member is obsolete + var now = AWSSDKUtils.CorrectedUtcNow; +#pragma warning restore CS0612,CS0618 // Type or member is obsolete + var exp = Expiration.ToUniversalTime(); + + return exp - now + preemptExpiryTime; + } } /// @@ -120,46 +126,106 @@ public TimeSpan PreemptExpiryTime /// public override ImmutableCredentials GetCredentials() { - _updateGeneratedCredentialsSemaphore.Wait(); - try + // We save the currentState as it might be modified or cleared. + var tempState = currentState; + + var ttl = tempState?.GetTimeToLive(PreemptExpiryTime); + + if (ttl > TimeSpan.Zero) { - // We save the currentState as it might be modified or cleared. - var tempState = currentState; - // If credentials are expired or we don't have any state yet, update - if (ShouldUpdateState(tempState, PreemptExpiryTime)) + if (ttl < PreemptExpiryTime) { - tempState = GenerateNewCredentials(); - UpdateToGeneratedCredentials(tempState, PreemptExpiryTime); - currentState = tempState; + // background refresh (fire & forget) + if (_updateGeneratedCredentialsSemaphore.Wait(0)) + { + _ = System.Threading.Tasks.Task.Run(GenerateCredentialsAndUpdateState); + } } - return tempState.Credentials.Copy(); } - finally + else { - _updateGeneratedCredentialsSemaphore.Release(); + // If credentials are expired, update + _updateGeneratedCredentialsSemaphore.Wait(); + tempState = GenerateCredentialsAndUpdateState(); + } + + return tempState.Credentials.Copy(); + + CredentialsRefreshState GenerateCredentialsAndUpdateState() + { + System.Diagnostics.Debug.Assert(_updateGeneratedCredentialsSemaphore.CurrentCount == 0); + + try + { + var tempState = currentState; + // double-check that the credentials still need updating + // as it's possible that multiple requests were queued acquiring the semaphore + if (ShouldUpdateState(tempState, PreemptExpiryTime)) + { + tempState = GenerateNewCredentials(); + UpdateToGeneratedCredentials(tempState, PreemptExpiryTime); + currentState = tempState; + } + + return tempState; + } + finally + { + _updateGeneratedCredentialsSemaphore.Release(); + } } } #if AWS_ASYNC_API public override async System.Threading.Tasks.Task GetCredentialsAsync() { - await _updateGeneratedCredentialsSemaphore.WaitAsync().ConfigureAwait(false); - try + // We save the currentState as it might be modified or cleared. + var tempState = currentState; + + var ttl = tempState?.GetTimeToLive(PreemptExpiryTime); + + if (ttl > TimeSpan.Zero) { - // We save the currentState as it might be modified or cleared. - var tempState = currentState; - // If credentials are expired, update - if (ShouldUpdateState(tempState, PreemptExpiryTime)) + if (ttl < PreemptExpiryTime) { - tempState = await GenerateNewCredentialsAsync().ConfigureAwait(false); - UpdateToGeneratedCredentials(tempState, PreemptExpiryTime); - currentState = tempState; + // background refresh (fire & forget) + if (_updateGeneratedCredentialsSemaphore.Wait(0)) + { + _ = GenerateCredentialsAndUpdateStateAsync(); + } } - return tempState.Credentials.Copy(); } - finally + else + { + // If credentials are expired, update + await _updateGeneratedCredentialsSemaphore.WaitAsync().ConfigureAwait(false); + tempState = await GenerateCredentialsAndUpdateStateAsync().ConfigureAwait(false); + } + + return tempState.Credentials.Copy(); + + async System.Threading.Tasks.Task GenerateCredentialsAndUpdateStateAsync() { - _updateGeneratedCredentialsSemaphore.Release(); + System.Diagnostics.Debug.Assert(_updateGeneratedCredentialsSemaphore.CurrentCount == 0); + + try + { + var tempState = currentState; + // double-check that the credentials still need updating + // as it's possible that multiple requests were queued acquiring the semaphore + if (ShouldUpdateState(tempState, PreemptExpiryTime)) + { + tempState = await GenerateNewCredentialsAsync().ConfigureAwait(false); + UpdateToGeneratedCredentials(tempState, PreemptExpiryTime); + currentState = tempState; + } + + return tempState; + } + finally + { + _updateGeneratedCredentialsSemaphore.Release(); + } } } #endif @@ -262,7 +328,7 @@ protected virtual CredentialsRefreshState GenerateNewCredentials() throw new NotImplementedException(); } #if AWS_ASYNC_API - /// + /// /// When overridden in a derived class, generates new credentials and new expiration date. /// /// Called on first credentials request and when expiration date is in the past. diff --git a/sdk/test/NetStandard/UnitTests/AWSSDK.UnitTests.Custom.NetStandard.csproj b/sdk/test/NetStandard/UnitTests/AWSSDK.UnitTests.Custom.NetStandard.csproj index 3015ffa7ab62..afb69884c4c9 100644 --- a/sdk/test/NetStandard/UnitTests/AWSSDK.UnitTests.Custom.NetStandard.csproj +++ b/sdk/test/NetStandard/UnitTests/AWSSDK.UnitTests.Custom.NetStandard.csproj @@ -59,7 +59,7 @@ This project file should not be used as part of a release pipeline. - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs b/sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs new file mode 100644 index 000000000000..3c1f2a7db29d --- /dev/null +++ b/sdk/test/NetStandard/UnitTests/Core/Credentials/RefreshingAWSCredentialsTests.cs @@ -0,0 +1,232 @@ +using System; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Amazon; +using Amazon.Runtime; +using Xunit; + +namespace UnitTests.NetStandard.Core.Credentials +{ + public sealed class RefreshingAWSCredentialsTests : IDisposable + { + private readonly Func _resetUtcNowSource; + + public RefreshingAWSCredentialsTests() + { + _resetUtcNowSource = AWSConfigs.utcNowSource; + } + + [Fact] + public void ConcurrentCallsToGetExpiredCrendentialsOnlyGeneratesNewCredentialsOnce() + { + var baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + var lifetime = TimeSpan.FromMinutes(60); + var mockCredentials = new MockRefreshingAWSCredentials(lifetime) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + AWSConfigs.utcNowSource = () => baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + AWSConfigs.utcNowSource = () => baseTimeUtc + lifetime; + + mockCredentials.CloseGenerateCredentialsGate(); // prevent GenerateNewCredentials from returning + + var concurrentCredentialTasks = Task.WhenAll( + Enumerable.Range(1, 5) + .Select(i => Task.Run(() => mockCredentials.GetCredentials())) + ); + + mockCredentials.OpenGenerateCredentialsGate(); // allow GenerateNewCredentials to complete + + var allCreds = concurrentCredentialTasks.Result; + Assert.NotEqual(initialCreds, allCreds[0]); + + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + for (var i = 1; i < allCreds.Length; i++) + { + Assert.Equal(allCreds[0], allCreds[i]); + } + } + + [Fact] + public void CredentialsAreRefreshedInImmediatelyWhenExpired() + { + var baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + var lifetime = TimeSpan.FromMinutes(60); + var mockCredentials = new MockRefreshingAWSCredentials(lifetime) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + AWSConfigs.utcNowSource = () => baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + AWSConfigs.utcNowSource = () => baseTimeUtc + lifetime; + var credsAfterExpiration = mockCredentials.GetCredentials(); + Assert.NotEqual(initialCreds, credsAfterExpiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Fact] + public async Task CredentialsAreRefreshedInImmediatelyWhenExpiredAsync() + { + var baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + var lifetime = TimeSpan.FromMinutes(60); + var mockCredentials = new MockRefreshingAWSCredentials(lifetime) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + AWSConfigs.utcNowSource = () => baseTimeUtc; + var initialCreds = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + + AWSConfigs.utcNowSource = () => baseTimeUtc + lifetime; + var credsAfterExpiration = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + Assert.NotEqual(initialCreds, credsAfterExpiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Fact] + public async Task ConcurrentCallsToGetExpiredCrendentialsOnlyGeneratesNewCredentialsOnceAsync() + { + var baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + var lifetime = TimeSpan.FromMinutes(60); + var mockCredentials = new MockRefreshingAWSCredentials(lifetime) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + AWSConfigs.utcNowSource = () => baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + AWSConfigs.utcNowSource = () => baseTimeUtc + lifetime; + + mockCredentials.CloseGenerateCredentialsGate(); // prevent GenerateNewCredentials from returning + + var concurrentCredentialTasks = Task.WhenAll( + Enumerable.Range(1, 5) + .Select(i => mockCredentials.GetCredentialsAsync()) + ); + + mockCredentials.OpenGenerateCredentialsGate(); // allow GenerateNewCredentials to complete + + var allCreds = await concurrentCredentialTasks; + + Assert.NotEqual(initialCreds, allCreds[0]); + + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + for (var i = 1; i < allCreds.Length; i++) + { + Assert.Equal(allCreds[0], allCreds[i]); + } + } + + [Fact] + public void CredentialsAreRefreshedInBackgroundDuringPreemptyExpiryPeriod() + { + var baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + var lifetime = TimeSpan.FromMinutes(60); + var mockCredentials = new MockRefreshingAWSCredentials(lifetime) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + AWSConfigs.utcNowSource = () => baseTimeUtc; + var initialCreds = mockCredentials.GetCredentials(); + + AWSConfigs.utcNowSource = () => baseTimeUtc + TimeSpan.FromMinutes(50); + var previousState = mockCredentials.CurrentState; + var credsDuringPreemptExpiry = mockCredentials.GetCredentials(); + Assert.Equal(initialCreds, credsDuringPreemptExpiry); + + // wait for background refresh to complete + Assert.True(SpinWait.SpinUntil(() => !ReferenceEquals(mockCredentials.CurrentState, previousState), 1_000)); + + var credsAfterRefresh = mockCredentials.GetCredentials(); + Assert.NotEqual(credsAfterRefresh, credsDuringPreemptExpiry); + Assert.Equal(AWSConfigs.utcNowSource() + lifetime - mockCredentials.PreemptExpiryTime, mockCredentials.CurrentState.Expiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + [Fact] + public async Task CredentialsAreRefreshedInBackgroundDuringPreemptyExpiryPeriodAsync() + { + var baseTimeUtc = new DateTime(1970, 1, 1).ToUniversalTime(); + var lifetime = TimeSpan.FromMinutes(60); + var mockCredentials = new MockRefreshingAWSCredentials(lifetime) + { + PreemptExpiryTime = TimeSpan.FromMinutes(15), + }; + + AWSConfigs.utcNowSource = () => baseTimeUtc; + var initialCreds = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + + AWSConfigs.utcNowSource = () => baseTimeUtc + TimeSpan.FromMinutes(50); + var previousState = mockCredentials.CurrentState; + var credsDuringPreemptExpiry = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + Assert.Equal(initialCreds, credsDuringPreemptExpiry); + + // wait for background refresh to complete + Assert.True(SpinWait.SpinUntil(() => !ReferenceEquals(mockCredentials.CurrentState, previousState), 1_000)); + + var credsAfterRefresh = await mockCredentials.GetCredentialsAsync().ConfigureAwait(false); + Assert.NotEqual(credsAfterRefresh, credsDuringPreemptExpiry); + Assert.Equal(AWSConfigs.utcNowSource() + lifetime - mockCredentials.PreemptExpiryTime, mockCredentials.CurrentState.Expiration); + Assert.Equal(2, mockCredentials.GeneratedTokenCount); + } + + public void Dispose() + { + AWSConfigs.utcNowSource = _resetUtcNowSource; + } + + // using a hand-written mock in order to have access to the protected fields + private sealed class MockRefreshingAWSCredentials : RefreshingAWSCredentials + { + private readonly TimeSpan _credentialsLifetime; + private readonly ManualResetEventSlim _generateCredsEvent; + private int _tokenCounter; + + public MockRefreshingAWSCredentials(TimeSpan credentialsLifetime) + { + _credentialsLifetime = credentialsLifetime; + _generateCredsEvent = new ManualResetEventSlim(initialState: true); + _tokenCounter = 0; + } + + public CredentialsRefreshState CurrentState => base.currentState; + + public int GeneratedTokenCount => _tokenCounter; + + public bool IsGenerateCredentialsGateClosed => !_generateCredsEvent.IsSet; + + public void OpenGenerateCredentialsGate() + { + _generateCredsEvent.Set(); + } + + public void CloseGenerateCredentialsGate() + { + _generateCredsEvent.Reset(); + } + + protected override CredentialsRefreshState GenerateNewCredentials() + { + _generateCredsEvent.Wait(); + return new CredentialsRefreshState + { + Credentials = new ImmutableCredentials("access_key_id", "secret_access_key", $"token_{Interlocked.Increment(ref _tokenCounter)}"), + Expiration = AWSConfigs.utcNowSource() + _credentialsLifetime, + }; + } + + protected override Task GenerateNewCredentialsAsync() + { + return Task.Run(GenerateNewCredentials); + } + } + } +}