From 018aaea871930cf2bd4574a1c0707c89b82f7d7c Mon Sep 17 00:00:00 2001 From: Alex Redding Date: Fri, 10 May 2024 01:07:28 -0500 Subject: [PATCH] fix: Remove early returns in claims transformation (#99) * fix: Remove early returns in claims transformation will allow it check all flags --- .../KeycloakRolesClaimsTransformation.cs | 96 +++++++++---------- .../KeycloakRolesClaimsTransformationTests.cs | 65 ++++++++++--- 2 files changed, 99 insertions(+), 62 deletions(-) diff --git a/src/Keycloak.AuthServices.Authorization/Claims/KeycloakRolesClaimsTransformation.cs b/src/Keycloak.AuthServices.Authorization/Claims/KeycloakRolesClaimsTransformation.cs index d168eb63..d5ac23c9 100644 --- a/src/Keycloak.AuthServices.Authorization/Claims/KeycloakRolesClaimsTransformation.cs +++ b/src/Keycloak.AuthServices.Authorization/Claims/KeycloakRolesClaimsTransformation.cs @@ -73,38 +73,34 @@ public Task TransformAsync(ClaimsPrincipal principal) if (this.roleSource.HasFlag(RolesClaimTransformationSource.ResourceAccess)) { var resourceAccessValue = principal.FindFirst("resource_access")?.Value; - if (string.IsNullOrWhiteSpace(resourceAccessValue)) + if (!string.IsNullOrWhiteSpace(resourceAccessValue)) { - return Task.FromResult(result); - } - - using var resourceAccess = JsonDocument.Parse(resourceAccessValue); - var containsAudienceRoles = resourceAccess.RootElement.TryGetProperty( - this.audience, - out var rolesElement - ); - - if (!containsAudienceRoles) - { - return Task.FromResult(result); - } - - var clientRoles = rolesElement.GetProperty("roles"); - - foreach (var role in clientRoles.EnumerateArray()) - { - var value = role.GetString(); - - var matchingClaim = identity.Claims.FirstOrDefault(claim => - claim.Type.Equals( - this.roleClaimType, - StringComparison.InvariantCultureIgnoreCase - ) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase) + using var resourceAccess = JsonDocument.Parse(resourceAccessValue); + var containsAudienceRoles = resourceAccess.RootElement.TryGetProperty( + this.audience, + out var rolesElement ); - if (matchingClaim is null && !string.IsNullOrWhiteSpace(value)) + if (containsAudienceRoles) { - identity.AddClaim(new Claim(this.roleClaimType, value)); + var clientRoles = rolesElement.GetProperty("roles"); + + foreach (var role in clientRoles.EnumerateArray()) + { + var value = role.GetString(); + + var matchingClaim = identity.Claims.FirstOrDefault(claim => + claim.Type.Equals( + this.roleClaimType, + StringComparison.InvariantCultureIgnoreCase + ) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase) + ); + + if (matchingClaim is null && !string.IsNullOrWhiteSpace(value)) + { + identity.AddClaim(new Claim(this.roleClaimType, value)); + } + } } } } @@ -112,34 +108,32 @@ out var rolesElement if (this.roleSource.HasFlag(RolesClaimTransformationSource.Realm)) { var realmAccessValue = principal.FindFirst("realm_access")?.Value; - if (string.IsNullOrWhiteSpace(realmAccessValue)) + if (!string.IsNullOrWhiteSpace(realmAccessValue)) { - return Task.FromResult(result); - } - - using var realmAccess = JsonDocument.Parse(realmAccessValue); + using var realmAccess = JsonDocument.Parse(realmAccessValue); - var containsRoles = realmAccess.RootElement.TryGetProperty( - "roles", - out var rolesElement - ); + var containsRoles = realmAccess.RootElement.TryGetProperty( + "roles", + out var rolesElement + ); - if (containsRoles) - { - foreach (var role in rolesElement.EnumerateArray()) + if (containsRoles) { - var value = role.GetString(); - - var matchingClaim = identity.Claims.FirstOrDefault(claim => - claim.Type.Equals( - this.roleClaimType, - StringComparison.InvariantCultureIgnoreCase - ) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase) - ); - - if (matchingClaim is null && !string.IsNullOrWhiteSpace(value)) + foreach (var role in rolesElement.EnumerateArray()) { - identity.AddClaim(new Claim(this.roleClaimType, value)); + var value = role.GetString(); + + var matchingClaim = identity.Claims.FirstOrDefault(claim => + claim.Type.Equals( + this.roleClaimType, + StringComparison.InvariantCultureIgnoreCase + ) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase) + ); + + if (matchingClaim is null && !string.IsNullOrWhiteSpace(value)) + { + identity.AddClaim(new Claim(this.roleClaimType, value)); + } } } } diff --git a/tests/Keycloak.AuthServices.Authorization.Tests/Claims/KeycloakRolesClaimsTransformationTests.cs b/tests/Keycloak.AuthServices.Authorization.Tests/Claims/KeycloakRolesClaimsTransformationTests.cs index 416bbffa..a8c1b301 100644 --- a/tests/Keycloak.AuthServices.Authorization.Tests/Claims/KeycloakRolesClaimsTransformationTests.cs +++ b/tests/Keycloak.AuthServices.Authorization.Tests/Claims/KeycloakRolesClaimsTransformationTests.cs @@ -82,6 +82,38 @@ public async Task ClaimsTransformationShouldHandleMissingResourceClaim() claimsPrincipal.Claims.Count(item => ClaimTypes.Role == item.Type).Should().Be(0); } + [Fact] + public async Task ClaimsTransformationShouldHandleMissingResourceClaimWithRealmRoles() + { + var target = new KeycloakRolesClaimsTransformation( + ClaimTypes.Role, + RolesClaimTransformationSource.All, + ClientId + ); + var claimsPrincipal = GetClaimsPrincipal(MyRealmClaimValue, null); + + claimsPrincipal = await target.TransformAsync(claimsPrincipal); + claimsPrincipal.HasClaim(ClaimTypes.Role, RealmRoleUserClaim).Should().BeTrue(); + claimsPrincipal.HasClaim(ClaimTypes.Role, RealmRoleSuperUserClaim).Should().BeTrue(); + claimsPrincipal.Claims.Count(item => ClaimTypes.Role == item.Type).Should().Be(2); + } + + [Fact] + public async Task ClaimsTransformationShouldHandleMissingRealmClaimWithResourceRoles() + { + var target = new KeycloakRolesClaimsTransformation( + ClaimTypes.Role, + RolesClaimTransformationSource.All, + ClientId + ); + var claimsPrincipal = GetClaimsPrincipal(null, MyResourceClaimValue); + + claimsPrincipal = await target.TransformAsync(claimsPrincipal); + claimsPrincipal.HasClaim(ClaimTypes.Role, AppRoleUserClaim).Should().BeTrue(); + claimsPrincipal.HasClaim(ClaimTypes.Role, AppRoleSuperUserClaim).Should().BeTrue(); + claimsPrincipal.Claims.Count(item => ClaimTypes.Role == item.Type).Should().Be(2); + } + private const string MyResourceClaimValue = /*lang=json,strict*/ """ { @@ -138,17 +170,28 @@ public async Task ClaimsTransformationShouldHandleMissingResourceClaim() // Get a claims principal that has all the appropriate claim details required for testing private static ClaimsPrincipal GetClaimsPrincipal( - string realmClaimValue, - string resourceClaimValue - ) => - new( - new ClaimsIdentity( - [ - new Claim(ResourceClaimType, resourceClaimValue, JsonValueType, MyUrl, MyUrl), - new Claim(RealmClaimType, realmClaimValue, JsonValueType, MyUrl, MyUrl), - ] - ) - ); + string? realmClaimValue, + string? resourceClaimValue + ) + { + var claimsIdentity = new ClaimsIdentity(); + + if (realmClaimValue != null) + { + claimsIdentity.AddClaim( + new Claim(RealmClaimType, realmClaimValue, JsonValueType, MyUrl, MyUrl) + ); + } + + if (resourceClaimValue != null) + { + claimsIdentity.AddClaim( + new Claim(ResourceClaimType, resourceClaimValue, JsonValueType, MyUrl, MyUrl) + ); + } + + return new ClaimsPrincipal(claimsIdentity); + } // Get a claims principal that has all the appropriate claim details required for testing private static ClaimsPrincipal GetClaimsPrincipalClaim(string claimValue) =>