diff --git a/orcid-web/src/main/java/org/orcid/authorization/authentication/MFAAuthenticationDetailsSource.java b/orcid-web/src/main/java/org/orcid/authorization/authentication/MFAAuthenticationDetailsSource.java new file mode 100644 index 00000000000..3a08a9f6b08 --- /dev/null +++ b/orcid-web/src/main/java/org/orcid/authorization/authentication/MFAAuthenticationDetailsSource.java @@ -0,0 +1,15 @@ +package org.orcid.authorization.authentication; + +import javax.servlet.http.HttpServletRequest; + +import org.springframework.security.authentication.AuthenticationDetailsSource; +import org.springframework.security.web.authentication.WebAuthenticationDetails; + +public class MFAAuthenticationDetailsSource implements AuthenticationDetailsSource { + + @Override + public MFAWebAuthenticationDetails buildDetails(HttpServletRequest context) { + return new MFAWebAuthenticationDetails(context); + } + +} diff --git a/orcid-web/src/main/java/org/orcid/authorization/authentication/MFAWebAuthenticationDetails.java b/orcid-web/src/main/java/org/orcid/authorization/authentication/MFAWebAuthenticationDetails.java new file mode 100644 index 00000000000..1f52f588be4 --- /dev/null +++ b/orcid-web/src/main/java/org/orcid/authorization/authentication/MFAWebAuthenticationDetails.java @@ -0,0 +1,77 @@ +package org.orcid.authorization.authentication; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpSession; + +import org.apache.thrift.TSerializable; +import org.springframework.security.web.authentication.WebAuthenticationDetails; + +import java.io.Serializable; +import java.util.Objects; + +public class MFAWebAuthenticationDetails implements Serializable { + + public static final String VERIFICATION_CODE_PARAMETER = "verificationCode"; + + public static final String RECOVERY_CODE_PARAMETER = "recoveryCode"; + + private final String verificationCode; + + private final String recoveryCode; + + private final String remoteAddress; + + private final String sessionId; + + public MFAWebAuthenticationDetails(HttpServletRequest request) { + verificationCode = getParameterOrAttribute(request, VERIFICATION_CODE_PARAMETER); + recoveryCode = getParameterOrAttribute(request, RECOVERY_CODE_PARAMETER); + remoteAddress = request.getRemoteAddr(); + HttpSession session = request.getSession(false); + sessionId = session != null ? session.getId() : null; + } + + public MFAWebAuthenticationDetails(String remoteAddress, String sessionId, String verificationCode, String recoveryCode) { + this.verificationCode = verificationCode; + this.recoveryCode = recoveryCode; + this.remoteAddress = remoteAddress; + this.sessionId = sessionId; + } + + private String getParameterOrAttribute(HttpServletRequest request, String name) { + String value = request.getParameter(name); + if (value == null) { + value = (String) request.getAttribute(name); + } + return value; + } + + public String getVerificationCode() { + return verificationCode; + } + + public String getRecoveryCode() { + return recoveryCode; + } + + public String getRemoteAddress() { + return remoteAddress; + } + + public String getSessionId() { + return sessionId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MFAWebAuthenticationDetails that = (MFAWebAuthenticationDetails) o; + return Objects.equals(verificationCode, that.verificationCode) && Objects.equals(recoveryCode, that.recoveryCode) && Objects.equals(remoteAddress, that.remoteAddress) && Objects.equals(sessionId, that.sessionId); + } + + @Override + public int hashCode() { + return Objects.hash(verificationCode, recoveryCode, remoteAddress, sessionId); + } +} diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationDetailsSource.java b/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationDetailsSource.java deleted file mode 100644 index 42a76e93f59..00000000000 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationDetailsSource.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.orcid.frontend.spring; - -import javax.servlet.http.HttpServletRequest; - -import org.springframework.security.authentication.AuthenticationDetailsSource; -import org.springframework.security.web.authentication.WebAuthenticationDetails; - -public class OrcidAuthenticationDetailsSource implements AuthenticationDetailsSource { - - @Override - public WebAuthenticationDetails buildDetails(HttpServletRequest context) { - return new OrcidWebAuthenticationDetails(context); - } - -} diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationProvider.java b/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationProvider.java index 3f31f7bd7ae..7af8b615e7f 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationProvider.java +++ b/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidAuthenticationProvider.java @@ -2,12 +2,12 @@ import java.time.Instant; import java.util.Date; -import java.util.List; import javax.annotation.Resource; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; +import org.orcid.authorization.authentication.MFAWebAuthenticationDetails; import org.orcid.core.manager.BackupCodeManager; import org.orcid.core.manager.ProfileEntityCacheManager; import org.orcid.core.manager.TwoFactorAuthenticationManager; @@ -130,13 +130,13 @@ public Authentication authenticate(Authentication auth) throws AuthenticationExc } if (profile.getUsing2FA()) { - String recoveryCode = ((OrcidWebAuthenticationDetails) auth.getDetails()).getRecoveryCode(); + String recoveryCode = ((MFAWebAuthenticationDetails) auth.getDetails()).getRecoveryCode(); if (recoveryCode != null && !recoveryCode.isEmpty()) { if (!backupCodeManager.verify(profile.getId(), recoveryCode)) { throw new Bad2FARecoveryCodeException(); } } else { - String verificationCode = ((OrcidWebAuthenticationDetails) auth.getDetails()).getVerificationCode(); + String verificationCode = ((MFAWebAuthenticationDetails) auth.getDetails()).getVerificationCode(); if (verificationCode == null || verificationCode.isEmpty()) { throw new VerificationCodeFor2FARequiredException(); } diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidWebAuthenticationDetails.java b/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidWebAuthenticationDetails.java deleted file mode 100644 index 4520e325cc2..00000000000 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/OrcidWebAuthenticationDetails.java +++ /dev/null @@ -1,41 +0,0 @@ -package org.orcid.frontend.spring; - -import javax.servlet.http.HttpServletRequest; - -import org.springframework.security.web.authentication.WebAuthenticationDetails; - -public class OrcidWebAuthenticationDetails extends WebAuthenticationDetails { - - private static final long serialVersionUID = 1L; - - public static final String VERIFICATION_CODE_PARAMETER = "verificationCode"; - - public static final String RECOVERY_CODE_PARAMETER = "recoveryCode"; - - private String verificationCode; - - private String recoveryCode; - - public OrcidWebAuthenticationDetails(HttpServletRequest request) { - super(request); - verificationCode = getParameterOrAttribute(request, VERIFICATION_CODE_PARAMETER); - recoveryCode = getParameterOrAttribute(request, RECOVERY_CODE_PARAMETER); - } - - private String getParameterOrAttribute(HttpServletRequest request, String name) { - String value = request.getParameter(name); - if (value == null) { - value = (String) request.getAttribute(name); - } - return value; - } - - public String getVerificationCode() { - return verificationCode; - } - - public String getRecoveryCode() { - return recoveryCode; - } - -} diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidBeanClassLoaderAware.java b/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidBeanClassLoaderAware.java index 82db1224df2..262b5e1b70f 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidBeanClassLoaderAware.java +++ b/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidBeanClassLoaderAware.java @@ -3,6 +3,8 @@ import com.fasterxml.jackson.annotation.*; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import org.orcid.authorization.authentication.MFAWebAuthenticationDetails; +import org.orcid.frontend.web.util.MFAWebAuthenticationDetailsDeserializer; import org.orcid.frontend.web.util.SwitchUserGrantedAuthorityDeserializer; import org.springframework.beans.factory.BeanClassLoaderAware; import org.springframework.context.annotation.Bean; @@ -34,6 +36,7 @@ private ObjectMapper objectMapper() { mapper.registerModules(new CoreJackson2Module()); mapper.addMixIn(String[].class, StringArrayMixin.class); mapper.addMixIn(SwitchUserGrantedAuthority.class, SwitchUserGrantedAuthorityMixin.class); + mapper.addMixIn(MFAWebAuthenticationDetails.class, MFAWebAuthenticationDetailsMixin.class); return mapper; } @@ -67,4 +70,15 @@ public SwitchUserGrantedAuthorityMixin(@JsonProperty("authority") String role, @ } } + + @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) + @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, + getterVisibility = JsonAutoDetect.Visibility.NONE, isGetterVisibility = JsonAutoDetect.Visibility.NONE, creatorVisibility = JsonAutoDetect.Visibility.ANY) + @JsonDeserialize(using = MFAWebAuthenticationDetailsDeserializer.class) + @JsonIgnoreProperties(ignoreUnknown = true) + abstract class MFAWebAuthenticationDetailsMixin { + @JsonCreator + MFAWebAuthenticationDetailsMixin(@JsonProperty("remoteAddress") String remoteAddress, @JsonProperty("sessionId") String sessionId, @JsonProperty("verificationCode") String verificationCode, @JsonProperty("recoveryCode") String recoveryCode) { + } + } } diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidRequestCache.java b/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidRequestCache.java index 92b02eb19ad..1e23c7d598d 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidRequestCache.java +++ b/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/OrcidRequestCache.java @@ -1,12 +1,12 @@ package org.orcid.frontend.spring.configuration; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import org.orcid.core.manager.impl.OrcidUrlManager; import org.springframework.security.web.savedrequest.HttpSessionRequestCache; import org.springframework.stereotype.Service; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + @Service("orcidRequestCache") public class OrcidRequestCache extends HttpSessionRequestCache { @Override diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/SessionCacheConfig.java b/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/SessionCacheConfig.java index 269f97ccbe0..39c7e088c41 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/SessionCacheConfig.java +++ b/orcid-web/src/main/java/org/orcid/frontend/spring/configuration/SessionCacheConfig.java @@ -5,26 +5,14 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.connection.RedisStandaloneConfiguration; import org.springframework.data.redis.connection.jedis.JedisClientConfiguration; import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; -import org.springframework.data.redis.core.RedisOperations; -import org.springframework.data.redis.core.RedisTemplate; -import org.springframework.data.redis.serializer.StringRedisSerializer; -import org.springframework.session.FlushMode; -import org.springframework.session.SaveMode; -import org.springframework.session.SessionRepository; -import org.springframework.session.config.annotation.web.http.EnableSpringHttpSession; -import org.springframework.session.data.redis.RedisSessionRepository; import org.springframework.session.data.redis.config.ConfigureRedisAction; -import org.springframework.session.data.redis.config.annotation.web.http.EnableRedisHttpSession; import org.springframework.session.web.context.AbstractHttpSessionApplicationInitializer; import org.springframework.session.web.http.CookieSerializer; import org.springframework.session.web.http.DefaultCookieSerializer; -import org.springframework.session.web.http.SessionRepositoryFilter; -import javax.servlet.ServletContext; import java.time.Duration; @Configuration diff --git a/orcid-web/src/main/java/org/orcid/frontend/spring/session/redis/OrcidRedisIndexedSessionRepository.java b/orcid-web/src/main/java/org/orcid/frontend/spring/session/redis/OrcidRedisIndexedSessionRepository.java index 57fcba0b08b..3c0039e42e2 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/spring/session/redis/OrcidRedisIndexedSessionRepository.java +++ b/orcid-web/src/main/java/org/orcid/frontend/spring/session/redis/OrcidRedisIndexedSessionRepository.java @@ -54,7 +54,7 @@ public class OrcidRedisIndexedSessionRepository implements FindByIndexNameSessio private FlushMode flushMode; private SaveMode saveMode; private final String PUBLIC_ORCID_PAGE_REGEX = "/(\\d{4}-){3,}\\d{3}[\\dX](/.+)"; - private final List urisToSkip = List.of("/2FA/status.json", "/account/", "/account/biographyForm.json", "/account/countryForm.json", "/account/delegates.json", "/account/emails.json", + private final List urisToSkipOnGet = List.of("/2FA/status.json", "/account/", "/account/biographyForm.json", "/account/countryForm.json", "/account/delegates.json", "/account/emails.json", "/account/get-trusted-orgs.json", "/account/nameForm.json", "/account/preferences.json", "/account/socialAccounts.json", "/affiliations/affiliationDetails.json", "/affiliations/affiliationGroups.json", "/assets/vectors/orcid.logo.icon.svg", "/config.json", "/delegators/delegators-and-me.json", "/fundings/fundingDetails.json", "/fundings/fundingGroups.json", "/inbox/notifications.json", "/inbox/totalCount.json", "/inbox/unreadCount.json", "/my-orcid/externalIdentifiers.json", "/my-orcid/keywordsForms.json", "/my-orcid/otherNamesForms.json", "/my-orcid/websitesForms.json", @@ -62,7 +62,9 @@ public class OrcidRedisIndexedSessionRepository implements FindByIndexNameSessio "/orgs/disambiguated/ROR", "/peer-reviews/peer-review.json", "/peer-reviews/peer-reviews-by-group-id.json", "/peer-reviews/peer-reviews-minimized.json", "/qr-code.png", "/research-resources/researchResource.json", "/research-resources/researchResourcePage.json", "/works/getWorkInfo.json", "/works/groupingSuggestions.json", "/works/idTypes.json", "/works/work.json", "/works/worksExtendedPage.json"); - private final Set SKIP_SAVE_SESSION = new HashSet<>(urisToSkip); + private final List urisToSkipAlways = List.of("/oauth/custom/register/validatePassword.json"); + private final Set GET_SKIP_SAVE_SESSION = new HashSet<>(urisToSkipOnGet); + private final Set ALWAYS_SKIP_SAVE_SESSION = new HashSet<>(urisToSkipAlways); public OrcidRedisIndexedSessionRepository(RedisOperations sessionRedisOperations) { this.flushMode = FlushMode.ON_SAVE; @@ -363,11 +365,10 @@ private BoundHashOperations getSessionBoundHashOperation private boolean updateSession() { ServletRequestAttributes att = (ServletRequestAttributes)RequestContextHolder.getRequestAttributes(); HttpServletRequest request = att.getRequest(); - if(request.getMethod().equals("GET")) { - String url = request.getRequestURI().substring(request.getContextPath().length()); - if(SKIP_SAVE_SESSION.contains(url) || url.matches(PUBLIC_ORCID_PAGE_REGEX)) { - return false; - } + String url = request.getRequestURI().substring(request.getContextPath().length()); + if((request.getMethod().equals("GET") && (GET_SKIP_SAVE_SESSION.contains(url) || url.matches(PUBLIC_ORCID_PAGE_REGEX))) + || ALWAYS_SKIP_SAVE_SESSION.contains(url)) { + return false; } return true; } @@ -409,7 +410,7 @@ public void setLastAccessedTime(Instant lastAccessedTime) { // TODO: REMOVE THIS BEFORE GOING LIVE!!!! ServletRequestAttributes att = (ServletRequestAttributes)RequestContextHolder.getRequestAttributes(); HttpServletRequest request = att.getRequest(); - System.out.println("REDIS_SESSION: setLastAccessedTime: " + request.getRequestURI().toString() + " - " + request.getMethod()); + logger.info("REDIS_SESSION: setLastAccessedTime: " + request.getRequestURI().toString() + " - " + request.getMethod()); this.cached.setLastAccessedTime(lastAccessedTime); this.delta.put("lastAccessedTime", this.getLastAccessedTime().toEpochMilli()); diff --git a/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthControllerBase.java b/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthControllerBase.java index b0e281dc1ca..e883d8f5a8c 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthControllerBase.java +++ b/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthControllerBase.java @@ -11,7 +11,7 @@ import org.orcid.core.manager.v3.read_only.RecordNameManagerReadOnly; import org.orcid.core.oauth.service.OrcidAuthorizationEndpoint; import org.orcid.core.oauth.service.OrcidOAuth2RequestValidator; -import org.orcid.frontend.spring.OrcidWebAuthenticationDetails; +import org.orcid.authorization.authentication.MFAWebAuthenticationDetails; import org.orcid.persistence.jpa.entities.ClientDetailsEntity; import org.orcid.pojo.ajaxForm.PojoUtil; import org.orcid.pojo.ajaxForm.RequestInfoForm; @@ -139,7 +139,7 @@ protected void copy(Map savedParams, Map param ****************************/ protected Authentication authenticateUser(HttpServletRequest request, String email, String password) { UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(email, password); - token.setDetails(new OrcidWebAuthenticationDetails(request)); + token.setDetails(new MFAWebAuthenticationDetails(request)); Authentication authentication = authenticationManager.authenticate(token); SecurityContextHolder.getContext().setAuthentication(authentication); return authentication; diff --git a/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthLoginController.java b/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthLoginController.java index 6ee7ba3d939..5b7914028ad 100644 --- a/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthLoginController.java +++ b/orcid-web/src/main/java/org/orcid/frontend/web/controllers/OauthLoginController.java @@ -15,7 +15,7 @@ import org.orcid.core.manager.v3.ProfileEntityManager; import org.orcid.core.security.UnclaimedProfileExistsException; import org.orcid.core.utils.OrcidRequestUtil; -import org.orcid.frontend.spring.OrcidWebAuthenticationDetails; +import org.orcid.authorization.authentication.MFAWebAuthenticationDetails; import org.orcid.frontend.web.controllers.helper.OauthHelper; import org.orcid.frontend.web.exception.Bad2FARecoveryCodeException; import org.orcid.frontend.web.exception.Bad2FAVerificationCodeException; @@ -195,11 +195,11 @@ public ModelAndView loginGetHandler(HttpServletRequest request, HttpServletRespo private void copy2FAFields(OauthAuthorizeForm form, HttpServletRequest request) { if (form.getVerificationCode() != null) { - request.setAttribute(OrcidWebAuthenticationDetails.VERIFICATION_CODE_PARAMETER, form.getVerificationCode().getValue()); + request.setAttribute(MFAWebAuthenticationDetails.VERIFICATION_CODE_PARAMETER, form.getVerificationCode().getValue()); } if (form.getRecoveryCode() != null) { - request.setAttribute(OrcidWebAuthenticationDetails.RECOVERY_CODE_PARAMETER, form.getRecoveryCode().getValue()); + request.setAttribute(MFAWebAuthenticationDetails.RECOVERY_CODE_PARAMETER, form.getRecoveryCode().getValue()); } } diff --git a/orcid-web/src/main/java/org/orcid/frontend/web/util/MFAWebAuthenticationDetailsDeserializer.java b/orcid-web/src/main/java/org/orcid/frontend/web/util/MFAWebAuthenticationDetailsDeserializer.java new file mode 100644 index 00000000000..71cd329e549 --- /dev/null +++ b/orcid-web/src/main/java/org/orcid/frontend/web/util/MFAWebAuthenticationDetailsDeserializer.java @@ -0,0 +1,30 @@ +package org.orcid.frontend.web.util; + +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.orcid.authorization.authentication.MFAWebAuthenticationDetails; + +import java.io.IOException; + +public class MFAWebAuthenticationDetailsDeserializer extends JsonDeserializer { + @Override + public MFAWebAuthenticationDetails deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JacksonException { + ObjectMapper mapper = (ObjectMapper) jsonParser.getCodec(); + JsonNode jsonNode = mapper.readTree(jsonParser); + JsonNode verificationCodeNode = jsonNode.get("verificationCode"); + JsonNode recoveryCodeNode = jsonNode.get("recoveryCode"); + JsonNode remoteAddressNode = jsonNode.get("remoteAddress"); + JsonNode sessionIdNode = jsonNode.get("sessionId"); + + String verificationCode = (verificationCodeNode != null && verificationCodeNode.isTextual()) ? verificationCodeNode.asText() : null; + String recoveryCode = (recoveryCodeNode != null && recoveryCodeNode.isTextual()) ? recoveryCodeNode.asText() : null; + String remoteAddress = (remoteAddressNode != null && remoteAddressNode.isTextual()) ? remoteAddressNode.asText() : null; + String sessionId = (sessionIdNode != null && sessionIdNode.isTextual()) ? sessionIdNode.asText() : null; + + return new MFAWebAuthenticationDetails(remoteAddress, sessionId, verificationCode, recoveryCode); + } +} diff --git a/orcid-web/src/main/resources/orcid-frontend-security.xml b/orcid-web/src/main/resources/orcid-frontend-security.xml index 1b99b724eb6..ab746ee0b82 100644 --- a/orcid-web/src/main/resources/orcid-frontend-security.xml +++ b/orcid-web/src/main/resources/orcid-frontend-security.xml @@ -69,7 +69,7 @@ class="org.orcid.frontend.spring.AjaxAuthenticationFailureHandler" /> + class="org.orcid.authorization.authentication.MFAAuthenticationDetailsSource" />