Skip to content

Commit

Permalink
Introduces a lock while loading credentials from Credential Source (#…
Browse files Browse the repository at this point in the history
…2438)

* Introduces a lock while loading credentials from Credential Source

* Make type nullable

* Updates Abstractions version and CredentialLoader implementation

* Remove unnecessary usings

* Address feedback
  • Loading branch information
sruke authored Sep 7, 2023
1 parent 52cec44 commit 050e3e0
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
<MicrosoftGraphVersion>4.34.0</MicrosoftGraphVersion>
<MicrosoftGraphBetaVersion>4.50.0-preview</MicrosoftGraphBetaVersion>
<MicrosoftExtensionsHttpVersion>3.1.3</MicrosoftExtensionsHttpVersion>
<MicrosoftIdentityAbstractions>4.0.0</MicrosoftIdentityAbstractions>
<MicrosoftIdentityAbstractions>4.1.0</MicrosoftIdentityAbstractions>
<!--CVE-2021-24112-->
<SystemDrawingCommon>4.7.2</SystemDrawingCommon>
</PropertyGroup>
Expand Down
23 changes: 20 additions & 3 deletions src/Microsoft.Identity.Web.Certificate/DefaultCredentialsLoader.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Identity.Abstractions;
Expand All @@ -15,6 +16,7 @@ namespace Microsoft.Identity.Web
public class DefaultCredentialsLoader : ICredentialsLoader
{
ILogger<DefaultCredentialsLoader>? _logger;
private readonly ConcurrentDictionary<string, SemaphoreSlim> _loadingSemaphores = new ConcurrentDictionary<string, SemaphoreSlim>();

/// <summary>
/// Constructor with a logger
Expand Down Expand Up @@ -56,9 +58,24 @@ public async Task LoadCredentialsIfNeededAsync(CredentialDescription credentialD

if (credentialDescription.CachedValue == null)
{
if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader))
// Get or create a semaphore for this credentialDescription
var semaphore = _loadingSemaphores.GetOrAdd(credentialDescription.Id, (v) => new SemaphoreSlim(1));

// Wait to acquire the semaphore
await semaphore.WaitAsync();

try
{
if (credentialDescription.CachedValue == null)
{
if (CredentialSourceLoaders.TryGetValue(credentialDescription.SourceType, out ICredentialSourceLoader? loader))
await loader.LoadIfNeededAsync(credentialDescription, parameters);
}
}
finally
{
await loader.LoadIfNeededAsync(credentialDescription, parameters);
// Release the semaphore
semaphore.Release();
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions tests/IntegrationTests/TokenAcquirerTests/TokenAcquirer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using Microsoft.Identity.Web.TokenCacheProviders.InMemory;
using Microsoft.IdentityModel.Tokens;
using Xunit;
using TaskStatus = System.Threading.Tasks.TaskStatus;

namespace TokenAcquirerTests
{
Expand Down Expand Up @@ -126,6 +127,46 @@ public async Task AcquireToken_WithFactoryAndAuthorityClientIdCert_ClientCredent
Assert.False(string.IsNullOrEmpty(result.AccessToken));
}

[IgnoreOnAzureDevopsFact]
//[Fact]
public async Task LoadCredentialsIfNeededAsync_MultipleThreads_WaitsForSemaphore()
{
TokenAcquirerFactory tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance();
IServiceCollection services = tokenAcquirerFactory.Services;

services.Configure<MicrosoftIdentityApplicationOptions>(s_optionName, option =>
{
option.Instance = "https://login.microsoftonline.com/";
option.TenantId = "msidentitysamplestesting.onmicrosoft.com";
option.ClientId = "6af093f3-b445-4b7a-beae-046864468ad6";
option.ClientCredentials = s_clientCredentials;
});

services.AddInMemoryTokenCaches();
var serviceProvider = tokenAcquirerFactory.Build();
var options = serviceProvider.GetRequiredService<IOptionsMonitor<MicrosoftIdentityApplicationOptions>>().Get(s_optionName);
var credentialsLoader = serviceProvider.GetRequiredService<ICredentialsLoader>();

var task1 = Task.Run(async () =>
{
await credentialsLoader.LoadCredentialsIfNeededAsync(options.ClientCredentials!.First());
});

var task2 = Task.Run(async () =>
{
await credentialsLoader.LoadCredentialsIfNeededAsync(options.ClientCredentials!.First());
});

// Run task1 and task2 concurrently
await Task.WhenAll(task1, task2);

var cert = options.ClientCredentials!.First().Certificate;

Assert.NotNull(cert);
Assert.Equal(TaskStatus.RanToCompletion, task1.Status);
Assert.Equal(TaskStatus.RanToCompletion, task2.Status);
}

[IgnoreOnAzureDevopsFact]
//[Fact]
public async Task AcquireTokenWithPop_ClientCredentialsAsync()
Expand Down

0 comments on commit 050e3e0

Please sign in to comment.