diff --git a/src/main/java/gov/cdc/izgateway/security/IzgPrincipal.java b/src/main/java/gov/cdc/izgateway/security/IzgPrincipal.java index 84961be..8370a08 100644 --- a/src/main/java/gov/cdc/izgateway/security/IzgPrincipal.java +++ b/src/main/java/gov/cdc/izgateway/security/IzgPrincipal.java @@ -5,6 +5,8 @@ import java.util.Date; import java.util.List; import java.util.Set; +import java.util.TreeSet; +import java.util.ArrayList; @Data public abstract class IzgPrincipal implements java.security.Principal { @@ -14,9 +16,9 @@ public abstract class IzgPrincipal implements java.security.Principal { Date validTo; String serialNumber; String issuer; - List audience; - Set scopes; - Set roles; + List audience = new ArrayList<>(); + Set scopes = new TreeSet<>(); + Set roles = new TreeSet<>(); public abstract String getSerialNumberHex(); } diff --git a/src/main/java/gov/cdc/izgateway/security/JWTPrincipal.java b/src/main/java/gov/cdc/izgateway/security/JWTPrincipal.java index 9f4f880..a273c6a 100644 --- a/src/main/java/gov/cdc/izgateway/security/JWTPrincipal.java +++ b/src/main/java/gov/cdc/izgateway/security/JWTPrincipal.java @@ -1,12 +1,9 @@ package gov.cdc.izgateway.security; -import lombok.Data; - import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.UUID; -@Data public class JWTPrincipal extends IzgPrincipal { public String getSerialNumberHex() { diff --git a/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java b/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java index 480def0..3cce765 100644 --- a/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java +++ b/src/main/java/gov/cdc/izgateway/security/principal/CertificatePrincipalProviderImpl.java @@ -24,8 +24,6 @@ public IzgPrincipal createPrincipalFromCertificate(HttpServletRequest request) { X509Certificate[] certs = (X509Certificate[]) request.getAttribute(Globals.CERTIFICATES_ATTR); if (certs == null || certs.length == 0) { - // Removing the warning log message as it is causing tests to fail - // log.warn("No certificates found in request."); return null; } diff --git a/src/main/java/gov/cdc/izgateway/security/principal/GroupToRoleMapper.java b/src/main/java/gov/cdc/izgateway/security/principal/GroupToRoleMapper.java new file mode 100644 index 0000000..09e22cb --- /dev/null +++ b/src/main/java/gov/cdc/izgateway/security/principal/GroupToRoleMapper.java @@ -0,0 +1,7 @@ +package gov.cdc.izgateway.security.principal; + +import java.util.Set; + +public interface GroupToRoleMapper { + Set mapGroupsToRoles(Set groups); +} diff --git a/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java b/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java index 647593d..c9299d5 100644 --- a/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java +++ b/src/main/java/gov/cdc/izgateway/security/principal/JwtSharedSecretPrincipalProvider.java @@ -9,11 +9,15 @@ import jakarta.servlet.http.HttpServletRequest; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; +import org.springframework.lang.Nullable; import org.springframework.stereotype.Component; import javax.crypto.SecretKey; import java.util.Collections; +import java.util.List; +import java.util.Set; import java.util.TreeSet; @Slf4j @@ -22,6 +26,16 @@ public class JwtSharedSecretPrincipalProvider implements JwtPrincipalProvider { @Value("${jwt.shared-secret:}") private String sharedSecret; + private final GroupToRoleMapper groupToRoleMapper; + private final ScopeToRoleMapper scopeToRoleMapper; + + @Autowired + public JwtSharedSecretPrincipalProvider(@Nullable GroupToRoleMapper groupToRoleMapper, + @Nullable ScopeToRoleMapper scopeToRoleMapper) { + this.groupToRoleMapper = groupToRoleMapper; + this.scopeToRoleMapper = scopeToRoleMapper; + } + @Override public IzgPrincipal createPrincipalFromJwt(HttpServletRequest request) { if (StringUtils.isBlank(sharedSecret)) { @@ -64,13 +78,23 @@ private IzgPrincipal buildPrincipal(Claims claims) { principal.setSerialNumber(claims.getId()); principal.setIssuer(claims.getIssuer()); principal.setAudience(Collections.singletonList(claims.getAudience())); - - TreeSet scopes = extractScopes(claims); - principal.setRoles(RoleMapper.mapScopesToRoles(scopes)); + addRolesFromScopes(claims, principal); + addRolesFromGroups(claims, principal); + log.debug("Roles created from JWT: {}", principal.getRoles()); return principal; } + private void addRolesFromScopes(Claims claims, IzgPrincipal principal) { + if (scopeToRoleMapper == null) { + log.debug("No scope to role mapper was set. Skipping scope to role mapping."); + return; + } + + TreeSet scopes = extractScopes(claims); + principal.getRoles().addAll(scopeToRoleMapper.mapScopesToRoles(scopes)); + } + private TreeSet extractScopes(Claims claims) { TreeSet scopes = new TreeSet<>(); String scopeString = claims.get("scope", String.class); @@ -79,4 +103,20 @@ private TreeSet extractScopes(Claims claims) { } return scopes; } + + private void addRolesFromGroups(Claims claims, IzgPrincipal principal) { + if (groupToRoleMapper == null) { + log.debug("No group to role mapper was set. Skipping group to role mapping."); + return; + } + + List groupsList = claims.get("groups", List.class); + if (groupsList == null || groupsList.isEmpty()) { + return; + } + Set groups = new TreeSet<>(groupsList); + + Set roles = groupToRoleMapper.mapGroupsToRoles(groups); + principal.getRoles().addAll(roles); + } } diff --git a/src/main/java/gov/cdc/izgateway/security/principal/ScopeToRoleMapper.java b/src/main/java/gov/cdc/izgateway/security/principal/ScopeToRoleMapper.java new file mode 100644 index 0000000..b48dc92 --- /dev/null +++ b/src/main/java/gov/cdc/izgateway/security/principal/ScopeToRoleMapper.java @@ -0,0 +1,7 @@ +package gov.cdc.izgateway.security.principal; + +import java.util.Set; + +public interface ScopeToRoleMapper { + Set mapScopesToRoles(Set scopes); +} diff --git a/src/main/java/gov/cdc/izgateway/security/principal/RoleMapper.java b/src/main/java/gov/cdc/izgateway/security/principal/ScopeToRoleMapperImpl.java similarity index 56% rename from src/main/java/gov/cdc/izgateway/security/principal/RoleMapper.java rename to src/main/java/gov/cdc/izgateway/security/principal/ScopeToRoleMapperImpl.java index bb9518f..bbb895e 100644 --- a/src/main/java/gov/cdc/izgateway/security/principal/RoleMapper.java +++ b/src/main/java/gov/cdc/izgateway/security/principal/ScopeToRoleMapperImpl.java @@ -1,12 +1,14 @@ package gov.cdc.izgateway.security.principal; +import org.springframework.stereotype.Component; + import java.util.Set; import java.util.TreeSet; -public class RoleMapper { - public static Set mapScopesToRoles(Set scopes) { +@Component +public class ScopeToRoleMapperImpl implements ScopeToRoleMapper { + public Set mapScopesToRoles(Set scopes) { // Until we've defined a mapping between scopes and roles, we'll just return the scopes as roles return new TreeSet<>(scopes); } - }