Skip to content

Commit

Permalink
Change ICurrentClientIdProvider to ICurrentUserProvider (#1660)
Browse files Browse the repository at this point in the history
  • Loading branch information
gunndabad authored Nov 7, 2024
1 parent c2a910f commit 55be8af
Show file tree
Hide file tree
Showing 19 changed files with 103 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace TeachingRecordSystem.Api.Infrastructure.RateLimiting;

public static class ServiceCollectionExtensions
{
private const string UnknownClientId = "__UNKNOWN__";
private const string UnknownUserPartitionKey = "__UNKNOWN__";

public static IServiceCollection AddRateLimiting(
this IServiceCollection services,
Expand Down Expand Up @@ -52,13 +52,13 @@ public static IServiceCollection AddRateLimiting(
{
var rateLimiterOptions = httpContext.RequestServices.GetRequiredService<IOptions<ClientIdRateLimiterOptions>>().Value;
var connectionMultiplexerFactory = () => httpContext.RequestServices.GetRequiredService<IConnectionMultiplexer>();
var clientId = ClaimsPrincipalCurrentClientProvider.GetCurrentClientIdFromHttpContext(httpContext) ?? UnknownClientId;
var clientRateLimit = rateLimiterOptions.ClientRateLimits.TryGetValue(clientId, out var windowOptions) ? windowOptions : rateLimiterOptions.DefaultRateLimit;
var partitionKey = ClaimsPrincipalCurrentUserProvider.TryGetCurrentClientIdFromHttpContext(httpContext, out var userId) ? userId.ToString() : UnknownUserPartitionKey;
var clientRateLimit = rateLimiterOptions.ClientRateLimits.TryGetValue(partitionKey, out var windowOptions) ? windowOptions : rateLimiterOptions.DefaultRateLimit;

// Window isn't available via RateLimitMetadata so stash it on the HttpContext instead
httpContext.Items.TryAdd(_windowSecondsHttpContextKey, clientRateLimit.Window.TotalSeconds);

return RedisRateLimitPartition.GetFixedWindowRateLimiter(clientId, key => new RedisFixedWindowRateLimiterOptions()
return RedisRateLimitPartition.GetFixedWindowRateLimiter(partitionKey, key => new RedisFixedWindowRateLimiterOptions()
{
Window = clientRateLimit.Window,
PermitLimit = clientRateLimit.PermitLimit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ protected async override Task<AuthenticateResult> HandleAuthenticateAsync()

var applicationUser = apiKey.ApplicationUser;

var principal = CreatePrincipal(applicationUser.UserId.ToString(), applicationUser.Name, applicationUser.ApiRoles ?? []);
var principal = CreatePrincipal(applicationUser.UserId, applicationUser.Name, applicationUser.ApiRoles ?? []);
var ticket = new AuthenticationTicket(principal, Scheme.Name);

LogContext.PushProperty("ApplicationUserId", applicationUser.UserId);
Expand All @@ -84,11 +84,11 @@ protected async override Task<AuthenticateResult> HandleAuthenticateAsync()
return AuthenticateResult.Success(ticket);
}

public static ClaimsPrincipal CreatePrincipal(string clientId, string name, IEnumerable<string> roles)
public static ClaimsPrincipal CreatePrincipal(Guid applicationUserId, string name, IEnumerable<string> roles)
{
var identity = new ClaimsIdentity(
[
new Claim("sub", clientId),
new Claim("sub", applicationUserId.ToString()),
new Claim(ClaimTypes.Name, name)
],
authenticationType: "Bearer",
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System.Security.Claims;

namespace TeachingRecordSystem.Api.Infrastructure.Security;

public class ClaimsPrincipalCurrentUserProvider(IHttpContextAccessor httpContextAccessor) : ICurrentUserProvider
{
public static bool TryGetCurrentClientIdFromHttpContext(HttpContext httpContext, out Guid userId)
{
var userIdStr = httpContext.User.FindFirstValue("sub");

if (userIdStr is null)
{
userId = default;
return false;
}

return Guid.TryParse(userIdStr, out userId);
}

public Guid GetCurrentApplicationUserId()
{
var httpContext = httpContextAccessor.HttpContext ?? throw new Exception("No HttpContext.");

if (!TryGetCurrentClientIdFromHttpContext(httpContext, out var userId))
{
throw new Exception("Current user has no 'sub' claim.");
}

return userId;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace TeachingRecordSystem.Api.Infrastructure.Security;

public interface ICurrentUserProvider
{
Guid GetCurrentApplicationUserId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public static void Main(string[] args)
});

services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining<Program>());
services.AddSingleton<ICurrentClientProvider, ClaimsPrincipalCurrentClientProvider>();
services.AddSingleton<ICurrentUserProvider, ClaimsPrincipalCurrentUserProvider>();
services.AddMemoryCache();
services.AddSingleton<AddTrnToSentryScopeResourceFilter>();
services.AddTransient<TrnRequestHelper>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ namespace TeachingRecordSystem.Api;

public class TrnRequestHelper(TrsDbContext dbContext, ICrmQueryDispatcher crmQueryDispatcher)
{
public async Task<GetTrnRequestResult?> GetTrnRequestInfo(string currentClientId, string requestId)
public async Task<GetTrnRequestResult?> GetTrnRequestInfo(Guid currentApplicationUserId, string requestId)
{
var getDbTrnRequestTask = dbContext.TrnRequests.SingleOrDefaultAsync(r => r.ClientId == currentClientId && r.RequestId == requestId);
var getDbTrnRequestTask = dbContext.TrnRequests.SingleOrDefaultAsync(r => r.ClientId == currentApplicationUserId.ToString() && r.RequestId == requestId);

var crmTrnRequestId = GetCrmTrnRequestId(currentClientId, requestId);
var crmTrnRequestId = GetCrmTrnRequestId(currentApplicationUserId, requestId);
var getContactByTrnRequestIdTask = crmQueryDispatcher.ExecuteQuery(
new GetContactByTrnRequestIdQuery(crmTrnRequestId, new Microsoft.Xrm.Sdk.Query.ColumnSet(Contact.Fields.ContactId, Contact.Fields.dfeta_TrnToken)));

Expand All @@ -29,8 +29,8 @@ public class TrnRequestHelper(TrsDbContext dbContext, ICrmQueryDispatcher crmQue
return null;
}

public static string GetCrmTrnRequestId(string currentClientId, string requestId) =>
$"{currentClientId}::{requestId}";
public static string GetCrmTrnRequestId(Guid currentApplicationUserId, string requestId) =>
$"{currentApplicationUserId}::{requestId}";
}

public record GetTrnRequestResult(Guid ContactId, string? TrnToken);
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,40 @@ public class GetOrCreateTrnRequestHandler : IRequestHandler<GetOrCreateTrnReques

private readonly TrnRequestHelper _trnRequestHelper;
private readonly IDataverseAdapter _dataverseAdapter;
private readonly ICurrentClientProvider _currentClientProvider;
private readonly ICurrentUserProvider _currentUserProvider;
private readonly IDistributedLockProvider _distributedLockProvider;
private readonly IGetAnIdentityApiClient _identityApiClient;
private readonly AccessYourTeachingQualificationsOptions _accessYourTeachingQualificationsOptions;

public GetOrCreateTrnRequestHandler(
TrnRequestHelper trnRequestHelper,
IDataverseAdapter dataverseAdapter,
ICurrentClientProvider currentClientProvider,
ICurrentUserProvider currentUserProvider,
IDistributedLockProvider distributedLockProvider,
IGetAnIdentityApiClient identityApiClient,
IOptions<AccessYourTeachingQualificationsOptions> accessYourTeachingQualificationsOptions)
{
_trnRequestHelper = trnRequestHelper;
_dataverseAdapter = dataverseAdapter;
_currentClientProvider = currentClientProvider;
_currentUserProvider = currentUserProvider;
_distributedLockProvider = distributedLockProvider;
_identityApiClient = identityApiClient;
_accessYourTeachingQualificationsOptions = accessYourTeachingQualificationsOptions.Value;
}

public async Task<TrnRequestInfo> Handle(GetOrCreateTrnRequest request, CancellationToken cancellationToken)
{
var currentClientId = _currentClientProvider.GetCurrentClientId();
var currentApplicationUserId = _currentUserProvider.GetCurrentApplicationUserId();

await using var requestIdLock = await _distributedLockProvider.AcquireLockAsync(
DistributedLockKeys.TrnRequestId(currentClientId, request.RequestId),
DistributedLockKeys.TrnRequestId(currentApplicationUserId, request.RequestId),
_lockTimeout);

await using var husidLock = !string.IsNullOrEmpty(request.HusId) ?
(IAsyncDisposable)await _distributedLockProvider.AcquireLockAsync(DistributedLockKeys.Husid(request.HusId), _lockTimeout) :
NoopAsyncDisposable.Instance;

var trnRequest = await _trnRequestHelper.GetTrnRequestInfo(currentClientId, request.RequestId);
var trnRequest = await _trnRequestHelper.GetTrnRequestInfo(currentApplicationUserId, request.RequestId);

bool wasCreated;
string trn;
Expand Down Expand Up @@ -136,7 +136,7 @@ public async Task<TrnRequestInfo> Handle(GetOrCreateTrnRequest request, Cancella
InductionRequired = request.InductionRequired,
UnderNewOverseasRegulations = request.UnderNewOverseasRegulations,
SlugId = request.SlugId,
TrnRequestId = TrnRequestHelper.GetCrmTrnRequestId(currentClientId, request.RequestId),
TrnRequestId = TrnRequestHelper.GetCrmTrnRequestId(currentApplicationUserId, request.RequestId),
GetTrnToken = GetTrnToken
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@ public class GetTrnRequestHandler : IRequestHandler<GetTrnRequest, TrnRequestInf
{
private readonly TrnRequestHelper _trnRequestHelper;
private readonly IDataverseAdapter _dataverseAdapter;
private readonly ICurrentClientProvider _currentClientProvider;
private readonly ICurrentUserProvider _currentUserProvider;
private readonly AccessYourTeachingQualificationsOptions _accessYourTeachingQualificationsOptions;

public GetTrnRequestHandler(
TrnRequestHelper trnRequestHelper,
TrsDbContext TrsDbContext,
IDataverseAdapter dataverseAdapter,
ICurrentClientProvider currentClientProvider,
ICurrentUserProvider currentUserProvider,
IOptions<AccessYourTeachingQualificationsOptions> accessYourTeachingQualificationsOptions)
{
_trnRequestHelper = trnRequestHelper;
_dataverseAdapter = dataverseAdapter;
_currentClientProvider = currentClientProvider;
_currentUserProvider = currentUserProvider;
_accessYourTeachingQualificationsOptions = accessYourTeachingQualificationsOptions.Value;
}

public async Task<TrnRequestInfo> Handle(GetTrnRequest request, CancellationToken cancellationToken)
{
var currentClientId = _currentClientProvider.GetCurrentClientId();
var currentApplicationUserId = _currentUserProvider.GetCurrentApplicationUserId();

var trnRequest = await _trnRequestHelper.GetTrnRequestInfo(currentClientId, request.RequestId);
var trnRequest = await _trnRequestHelper.GetTrnRequestInfo(currentApplicationUserId, request.RequestId);
if (trnRequest == null)
{
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ public class CreateTrnRequestHandler(
TrsDbContext dbContext,
ICrmQueryDispatcher crmQueryDispatcher,
TrnRequestHelper trnRequestHelper,
ICurrentClientProvider currentClientProvider,
ICurrentUserProvider currentUserProvider,
ITrnGenerationApiClient trnGenerationApiClient,
INameSynonymProvider nameSynonymProvider)
{
public async Task<TrnRequestInfo> Handle(CreateTrnRequestCommand command)
{
var currentClientId = currentClientProvider.GetCurrentClientId();
var currentApplicationUserId = currentUserProvider.GetCurrentApplicationUserId();

var trnRequest = await trnRequestHelper.GetTrnRequestInfo(currentClientId, command.RequestId);
var trnRequest = await trnRequestHelper.GetTrnRequestInfo(currentApplicationUserId, command.RequestId);
if (trnRequest is not null)
{
throw new ErrorException(ErrorRegistry.CannotResubmitRequest());
Expand Down Expand Up @@ -116,7 +116,7 @@ await crmQueryDispatcher.ExecuteQuery(new CreateContactQuery()
NationalInsuranceNumber = NationalInsuranceNumberHelper.Normalize(command.NationalInsuranceNumber),
PotentialDuplicates = potentialDuplicates,
Trn = trn,
TrnRequestId = TrnRequestHelper.GetCrmTrnRequestId(currentClientId, command.RequestId),
TrnRequestId = TrnRequestHelper.GetCrmTrnRequestId(currentApplicationUserId, command.RequestId),
});

var status = trn is not null ? TrnRequestStatus.Completed : TrnRequestStatus.Pending;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ public record GetTrnRequestCommand(string RequestId);
public class GetTrnRequestHandler(
ICrmQueryDispatcher crmQueryDispatcher,
TrnRequestHelper trnRequestHelper,
ICurrentClientProvider currentClientProvider)
ICurrentUserProvider currentUserProvider)
{
public async Task<TrnRequestInfo?> Handle(GetTrnRequestCommand command)
{
var currentClientId = currentClientProvider.GetCurrentClientId();
var currentApplicationUserId = currentUserProvider.GetCurrentApplicationUserId();

var trnRequest = await trnRequestHelper.GetTrnRequestInfo(currentClientId, command.RequestId);
var trnRequest = await trnRequestHelper.GetTrnRequestInfo(currentApplicationUserId, command.RequestId);
if (trnRequest is null)
{
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public static class DistributedLockKeys
public static string EntityChanges(string changesKey, string entityLogicalName) => $"entity-changes:{changesKey}/{entityLogicalName}";
public static string Husid(string husid) => $"husid:{husid}";
public static string Trn(string trn) => $"trn:{trn}";
public static string TrnRequestId(string clientId, string requestId) => $"trn-request:{clientId}/{requestId}";
public static string TrnRequestId(Guid applicationUserId, string requestId) => $"trn-request:{applicationUserId}/{requestId}";
public static string DqtReportingReplicationSlot() => nameof(DqtReportingReplicationSlot);
public static string DqtReportingMigrations() => nameof(DqtReportingMigrations);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ protected override Task<AuthenticateResult> HandleAuthenticateAsync()
return Task.FromResult(AuthenticateResult.NoResult());
}

var currentApiClientId = currentApiClientProvider.CurrentApiClientId;
var currentApiClientId = currentApiClientProvider.CurrentApiUserId;
var currentRoles = currentApiClientProvider.Roles!;

if (currentApiClientId is not null)
if (currentApiClientId is Guid id)
{
var principal = ApiKeyAuthenticationHandler.CreatePrincipal(currentApiClientId, name: currentApiClientId, currentRoles);
var principal = ApiKeyAuthenticationHandler.CreatePrincipal(id, name: id.ToString(), currentRoles);

var ticket = new AuthenticationTicket(principal, Scheme.Name);

Expand All @@ -42,14 +42,14 @@ public class TestApiKeyAuthenticationOptions : AuthenticationSchemeOptions { }

public class CurrentApiClientProvider
{
private readonly AsyncLocal<string> _currentApiClientId = new();
private readonly AsyncLocal<Guid?> _currentApiUserId = new();
private readonly AsyncLocal<string[]> _roles = new();

[DisallowNull]
public string? CurrentApiClientId
public Guid? CurrentApiUserId
{
get => _currentApiClientId.Value;
set => _currentApiClientId.Value = value;
get => _currentApiUserId.Value;
set => _currentApiUserId.Value = value;
}

[DisallowNull]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@ namespace TeachingRecordSystem.Api.Tests;

public abstract class TestBase
{
private static readonly Guid _defaultApplicationUserId = new Guid("c0c8c511-e8e4-4b8e-96e3-55085dafc05d");

private readonly TestScopedServices _testServices;

protected TestBase(HostFixture hostFixture)
{
HostFixture = hostFixture;
_testServices = TestScopedServices.Reset();
SetCurrentApiClient(Array.Empty<string>());
SetCurrentApiClient([]);
}

public HostFixture HostFixture { get; }

public Mock<ICertificateGenerator> CertificateGeneratorMock => _testServices.CertificateGeneratorMock;

public string ClientId { get; } = "tests";
public Guid ApplicationUserId { get; } = _defaultApplicationUserId;

public CrmQueryDispatcherSpy CrmQueryDispatcherSpy => _testServices.CrmQueryDispatcherSpy;

Expand Down Expand Up @@ -97,10 +99,10 @@ public HttpClient GetHttpClientWithIdentityAccessToken(string trn, string scope
return httpClient;
}

protected void SetCurrentApiClient(IEnumerable<string> roles, string clientId = "tests")
protected void SetCurrentApiClient(IEnumerable<string> roles, Guid? applicationUserId = null)
{
var currentUserProvider = HostFixture.Services.GetRequiredService<CurrentApiClientProvider>();
currentUserProvider.CurrentApiClientId = clientId;
currentUserProvider.CurrentApiUserId = applicationUserId ?? _defaultApplicationUserId;
currentUserProvider.Roles = roles.ToArray();
}

Expand Down
Loading

0 comments on commit 55be8af

Please sign in to comment.