Skip to content

Commit

Permalink
fix: Remove early returns in claims transformation (#99)
Browse files Browse the repository at this point in the history
* fix: Remove early returns in claims transformation will allow it check all flags
  • Loading branch information
Alexr03 committed May 10, 2024
1 parent 056e649 commit 018aaea
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,73 +73,67 @@ public Task<ClaimsPrincipal> TransformAsync(ClaimsPrincipal principal)
if (this.roleSource.HasFlag(RolesClaimTransformationSource.ResourceAccess))
{
var resourceAccessValue = principal.FindFirst("resource_access")?.Value;
if (string.IsNullOrWhiteSpace(resourceAccessValue))
if (!string.IsNullOrWhiteSpace(resourceAccessValue))
{
return Task.FromResult(result);
}

using var resourceAccess = JsonDocument.Parse(resourceAccessValue);
var containsAudienceRoles = resourceAccess.RootElement.TryGetProperty(
this.audience,
out var rolesElement
);

if (!containsAudienceRoles)
{
return Task.FromResult(result);
}

var clientRoles = rolesElement.GetProperty("roles");

foreach (var role in clientRoles.EnumerateArray())
{
var value = role.GetString();

var matchingClaim = identity.Claims.FirstOrDefault(claim =>
claim.Type.Equals(
this.roleClaimType,
StringComparison.InvariantCultureIgnoreCase
) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase)
using var resourceAccess = JsonDocument.Parse(resourceAccessValue);
var containsAudienceRoles = resourceAccess.RootElement.TryGetProperty(
this.audience,
out var rolesElement
);

if (matchingClaim is null && !string.IsNullOrWhiteSpace(value))
if (containsAudienceRoles)
{
identity.AddClaim(new Claim(this.roleClaimType, value));
var clientRoles = rolesElement.GetProperty("roles");

foreach (var role in clientRoles.EnumerateArray())
{
var value = role.GetString();

var matchingClaim = identity.Claims.FirstOrDefault(claim =>
claim.Type.Equals(
this.roleClaimType,
StringComparison.InvariantCultureIgnoreCase
) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase)
);

if (matchingClaim is null && !string.IsNullOrWhiteSpace(value))
{
identity.AddClaim(new Claim(this.roleClaimType, value));
}
}
}
}
}

if (this.roleSource.HasFlag(RolesClaimTransformationSource.Realm))
{
var realmAccessValue = principal.FindFirst("realm_access")?.Value;
if (string.IsNullOrWhiteSpace(realmAccessValue))
if (!string.IsNullOrWhiteSpace(realmAccessValue))
{
return Task.FromResult(result);
}

using var realmAccess = JsonDocument.Parse(realmAccessValue);
using var realmAccess = JsonDocument.Parse(realmAccessValue);

var containsRoles = realmAccess.RootElement.TryGetProperty(
"roles",
out var rolesElement
);
var containsRoles = realmAccess.RootElement.TryGetProperty(
"roles",
out var rolesElement
);

if (containsRoles)
{
foreach (var role in rolesElement.EnumerateArray())
if (containsRoles)
{
var value = role.GetString();

var matchingClaim = identity.Claims.FirstOrDefault(claim =>
claim.Type.Equals(
this.roleClaimType,
StringComparison.InvariantCultureIgnoreCase
) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase)
);

if (matchingClaim is null && !string.IsNullOrWhiteSpace(value))
foreach (var role in rolesElement.EnumerateArray())
{
identity.AddClaim(new Claim(this.roleClaimType, value));
var value = role.GetString();

var matchingClaim = identity.Claims.FirstOrDefault(claim =>
claim.Type.Equals(
this.roleClaimType,
StringComparison.InvariantCultureIgnoreCase
) && claim.Value.Equals(value, StringComparison.InvariantCultureIgnoreCase)
);

if (matchingClaim is null && !string.IsNullOrWhiteSpace(value))
{
identity.AddClaim(new Claim(this.roleClaimType, value));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,38 @@ public async Task ClaimsTransformationShouldHandleMissingResourceClaim()
claimsPrincipal.Claims.Count(item => ClaimTypes.Role == item.Type).Should().Be(0);
}

[Fact]
public async Task ClaimsTransformationShouldHandleMissingResourceClaimWithRealmRoles()
{
var target = new KeycloakRolesClaimsTransformation(
ClaimTypes.Role,
RolesClaimTransformationSource.All,
ClientId
);
var claimsPrincipal = GetClaimsPrincipal(MyRealmClaimValue, null);

claimsPrincipal = await target.TransformAsync(claimsPrincipal);
claimsPrincipal.HasClaim(ClaimTypes.Role, RealmRoleUserClaim).Should().BeTrue();
claimsPrincipal.HasClaim(ClaimTypes.Role, RealmRoleSuperUserClaim).Should().BeTrue();
claimsPrincipal.Claims.Count(item => ClaimTypes.Role == item.Type).Should().Be(2);
}

[Fact]
public async Task ClaimsTransformationShouldHandleMissingRealmClaimWithResourceRoles()
{
var target = new KeycloakRolesClaimsTransformation(
ClaimTypes.Role,
RolesClaimTransformationSource.All,
ClientId
);
var claimsPrincipal = GetClaimsPrincipal(null, MyResourceClaimValue);

claimsPrincipal = await target.TransformAsync(claimsPrincipal);
claimsPrincipal.HasClaim(ClaimTypes.Role, AppRoleUserClaim).Should().BeTrue();
claimsPrincipal.HasClaim(ClaimTypes.Role, AppRoleSuperUserClaim).Should().BeTrue();
claimsPrincipal.Claims.Count(item => ClaimTypes.Role == item.Type).Should().Be(2);
}

private const string MyResourceClaimValue = /*lang=json,strict*/
"""
{
Expand Down Expand Up @@ -138,17 +170,28 @@ public async Task ClaimsTransformationShouldHandleMissingResourceClaim()

// Get a claims principal that has all the appropriate claim details required for testing
private static ClaimsPrincipal GetClaimsPrincipal(
string realmClaimValue,
string resourceClaimValue
) =>
new(
new ClaimsIdentity(
[
new Claim(ResourceClaimType, resourceClaimValue, JsonValueType, MyUrl, MyUrl),
new Claim(RealmClaimType, realmClaimValue, JsonValueType, MyUrl, MyUrl),
]
)
);
string? realmClaimValue,
string? resourceClaimValue
)
{
var claimsIdentity = new ClaimsIdentity();

if (realmClaimValue != null)
{
claimsIdentity.AddClaim(
new Claim(RealmClaimType, realmClaimValue, JsonValueType, MyUrl, MyUrl)
);
}

if (resourceClaimValue != null)
{
claimsIdentity.AddClaim(
new Claim(ResourceClaimType, resourceClaimValue, JsonValueType, MyUrl, MyUrl)
);
}

return new ClaimsPrincipal(claimsIdentity);
}

// Get a claims principal that has all the appropriate claim details required for testing
private static ClaimsPrincipal GetClaimsPrincipalClaim(string claimValue) =>
Expand Down

0 comments on commit 018aaea

Please sign in to comment.