Skip to content

Commit

Permalink
Allow setting origin in token provider
Browse files Browse the repository at this point in the history
  • Loading branch information
cjmalloy committed Sep 24, 2023
1 parent 6979a6a commit f4fbb24
Show file tree
Hide file tree
Showing 17 changed files with 185 additions and 159 deletions.
2 changes: 1 addition & 1 deletion src/main/java/jasper/config/SecurityConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public void configure(HttpSecurity http) throws Exception {

private JWTConfigurer securityConfigurerAdapter() {
logger.info("Minimum Role: {}", props.getMinRole());
return new JWTConfigurer(tokenProvider, defaultTokenProvider);
return new JWTConfigurer(props, tokenProvider, defaultTokenProvider);
}

@Bean
Expand Down
50 changes: 35 additions & 15 deletions src/main/java/jasper/config/WebSocketConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;

import javax.servlet.http.HttpServletRequest;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;

import static jasper.security.Auth.LOCAL_ORIGIN_HEADER;
import static org.apache.commons.lang3.StringUtils.isNotBlank;

@Profile("!no-websocket")
Expand All @@ -47,6 +49,9 @@ public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {

public static final String AUTHORIZATION_HEADER = "Authorization";

@Autowired
Props props;

@Autowired
TokenProvider tokenProvider;

Expand Down Expand Up @@ -79,11 +84,11 @@ public void configureMessageBroker(MessageBrokerRegistry config) {
public void registerStompEndpoints(StompEndpointRegistry registry) {
registry
.addEndpoint("/")
.setHandshakeHandler(new StompDefaultHandshakeHandler())
.addInterceptors(new StompHandshakeInterceptor())
.setHandshakeHandler(new StompDefaultHandshakeHandler())
.addInterceptors(new StompHandshakeInterceptor())
.setAllowedOriginPatterns("*")
.withSockJS()
.setSuppressCors(true);
.withSockJS()
.setSuppressCors(true);
}

@Override
Expand All @@ -92,14 +97,27 @@ public void configureClientInboundChannel(ChannelRegistration registration) {
}

class StompHandshakeInterceptor implements HandshakeInterceptor {

private String resolveOrigin(HttpServletRequest request) {
var origin = props.getLocalOrigin();
var headerOrigin = request.getHeader(LOCAL_ORIGIN_HEADER);
logger.debug("STOMP Local Origin Header: {}", headerOrigin);
if (props.isAllowLocalOriginHeader() && headerOrigin != null) {
return headerOrigin.toLowerCase();
}
return origin;
}

@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) {
logger.debug("STOMP Handshake");
if (request instanceof ServletServerHttpRequest servletRequest) {
var httpServletRequest = servletRequest.getServletRequest();
var token = httpServletRequest.getHeader(AUTHORIZATION_HEADER);
if (isNotBlank(token)) {
attributes.put("jwt", token.substring("Bearer ".length()));
}
attributes.put("origin", resolveOrigin(httpServletRequest));
}
return true;
}
Expand All @@ -111,11 +129,11 @@ public void afterHandshake(ServerHttpRequest request, ServerHttpResponse respons
class StompDefaultHandshakeHandler extends DefaultHandshakeHandler {
@Override
public Principal determineUser(ServerHttpRequest request, WebSocketHandler handler, Map<String, Object> attributes) {
if (!attributes.containsKey("jwt")) return request.getPrincipal();
logger.debug("STOMP Request Principal: " + request.getPrincipal());
var origin = (String) attributes.get("origin");
logger.debug("STOMP Determine User: " + origin);
if (!attributes.containsKey("jwt")) return defaultTokenProvider.getAuthentication(null, origin);
var token = (String) attributes.get("jwt");
TokenProvider t = tokenProvider.validateToken(token) ? tokenProvider : defaultTokenProvider;
return t.getAuthentication(token);
return tokenProvider.validateToken(token, origin) ? tokenProvider.getAuthentication(token, origin) : defaultTokenProvider.getAuthentication(null, origin);
}
}

Expand All @@ -127,17 +145,19 @@ public Message<?> preSend(Message<?> message, MessageChannel channel) {
if (accessor.getCommand() == StompCommand.BEGIN) return null; // No Transactions
if (accessor.getCommand() == StompCommand.SEND) return null; // No Client Messages
if (accessor.getCommand() != StompCommand.SUBSCRIBE) return message;
var headers = message.getHeaders().get("nativeHeaders", Map.class);
var token = ((ArrayList<String>) headers.get("jwt")).get(0);
if (tokenProvider.validateToken(token)) {
logger.debug("STOMP SUBSCRIBE Credentials Header");
auth.clear(tokenProvider.getAuthentication(token));
} else if (accessor.getUser() instanceof Authentication authentication) {
if (accessor.getUser() instanceof Authentication authentication) {
logger.debug("STOMP User Set");
auth.clear(authentication);
var headers = message.getHeaders().get("nativeHeaders", Map.class);
var token = ((ArrayList<String>) headers.get("jwt")).get(0);
var origin = auth.getOrigin();
if (tokenProvider.validateToken(token, origin)) {
logger.debug("STOMP SUBSCRIBE Credentials Header");
auth.clear(tokenProvider.getAuthentication(token, origin));
}
} else {
logger.debug("STOMP Default auth");
auth.clear(defaultTokenProvider.getAuthentication(null));
auth.clear(defaultTokenProvider.getAuthentication(null, props.getLocalOrigin()));
}
if (auth.canSubscribeTo(accessor.getDestination())) return message;
logger.error("{} can't subscribe to {}", auth.getUserTag(), accessor.getDestination());
Expand Down
16 changes: 10 additions & 6 deletions src/main/java/jasper/security/Auth.java
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ public void clear(Authentication authentication) {
writeAccess = null;
tagReadAccess = null;
tagWriteAccess = null;
origin = qt(getPrincipal()).origin;
if (getPrincipal().startsWith("@")) {
origin = getPrincipal();
} else {
origin = qt(getPrincipal()).origin;
}
}

@PostConstruct
Expand Down Expand Up @@ -221,7 +225,7 @@ public boolean freshLogin() {
* a default user tag +user.
*/
public boolean isLoggedIn() {
return isNotBlank(getPrincipal());
return isNotBlank(getPrincipal()) && !getPrincipal().startsWith("@");
}

/**
Expand Down Expand Up @@ -734,8 +738,8 @@ public List<QualifiedTag> getReadAccess() {
if (getClient().isAllowAuthHeaders()) {
readAccess.addAll(getHeaderQualifiedTags(READ_ACCESS_HEADER));
}
readAccess.addAll(getClaimQualifiedTags(getClient().getReadAccessClaim()));
if (isLoggedIn()) {
readAccess.addAll(getClaimQualifiedTags(getClient().getReadAccessClaim()));
readAccess.addAll(selectors(getMultiTenantOrigin(), getUser()
.map(User::getReadAccess)
.orElse(List.of())));
Expand All @@ -753,8 +757,8 @@ public List<QualifiedTag> getWriteAccess() {
if (getClient().isAllowAuthHeaders()) {
writeAccess.addAll(getHeaderQualifiedTags(WRITE_ACCESS_HEADER));
}
writeAccess.addAll(getClaimQualifiedTags(getClient().getWriteAccessClaim()));
if (isLoggedIn()) {
writeAccess.addAll(getClaimQualifiedTags(getClient().getWriteAccessClaim()));
writeAccess.addAll(selectors(getMultiTenantOrigin(), getUser()
.map(User::getWriteAccess)
.orElse(List.of())));
Expand All @@ -772,8 +776,8 @@ public List<QualifiedTag> getTagReadAccess() {
if (getClient().isAllowAuthHeaders()) {
tagReadAccess.addAll(getHeaderQualifiedTags(TAG_READ_ACCESS_HEADER));
}
tagReadAccess.addAll(getClaimQualifiedTags(getClient().getTagReadAccessClaim()));
if (isLoggedIn()) {
tagReadAccess.addAll(getClaimQualifiedTags(getClient().getTagReadAccessClaim()));
tagReadAccess.addAll(selectors(getMultiTenantOrigin(), getUser()
.map(User::getTagReadAccess)
.orElse(List.of())));
Expand All @@ -791,8 +795,8 @@ public List<QualifiedTag> getTagWriteAccess() {
if (getClient().isAllowAuthHeaders()) {
tagWriteAccess.addAll(getHeaderQualifiedTags(TAG_WRITE_ACCESS_HEADER));
}
tagWriteAccess.addAll(getClaimQualifiedTags(getClient().getTagWriteAccessClaim()));
if (isLoggedIn()) {
tagWriteAccess.addAll(getClaimQualifiedTags(getClient().getTagWriteAccessClaim()));
tagWriteAccess.addAll(selectors(getMultiTenantOrigin(), getUser()
.map(User::getTagWriteAccess)
.orElse(List.of())));
Expand Down
23 changes: 11 additions & 12 deletions src/main/java/jasper/security/jwt/AbstractJwtTokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ public abstract class AbstractJwtTokenProvider extends AbstractTokenProvider imp
this.securityMetersService = securityMetersService;
}

Collection<? extends GrantedAuthority> getAuthorities(Claims claims, User user) {
var auth = getPartialAuthorities(claims);
Collection<? extends GrantedAuthority> getAuthorities(Claims claims, User user, String origin) {
var auth = getPartialAuthorities(claims, origin);
if (user != null && user.getRole() != null) {
logger.debug("User Roles: {}", user.getRole());
if (User.ROLES.contains(user.getRole().trim())) {
Expand All @@ -64,9 +64,9 @@ Collection<? extends GrantedAuthority> getAuthorities(Claims claims, User user)
return auth;
}

List<SimpleGrantedAuthority> getPartialAuthorities(Claims claims) {
var auth = getPartialAuthorities();
var client = props.getSecurity().getClient(getPartialOrigin());
List<SimpleGrantedAuthority> getPartialAuthorities(Claims claims, String origin) {
var auth = getPartialAuthorities(origin);
var client = props.getSecurity().getClient(origin);
var authClaim = claims.get(client.getAuthoritiesClaim(), String.class);
if (isNotBlank(authClaim)) {
Arrays.stream(authClaim.split(","))
Expand All @@ -78,15 +78,14 @@ List<SimpleGrantedAuthority> getPartialAuthorities(Claims claims) {
return auth;
}

String getUsername(Claims claims) {
var client = props.getSecurity().getClient(getPartialOrigin());
String getUsername(Claims claims, String origin) {
var client = props.getSecurity().getClient(origin);
if (client.isAllowUserTagHeader() && !isBlank(getHeader(USER_TAG_HEADER))) {
return getHeader(USER_TAG_HEADER);
}
logger.debug("Sub: {}", client.getUsernameClaim());
var principal = claims.get(client.getUsernameClaim(), String.class);
logger.debug("Principal: {}", principal);
var origin = props.getLocalOrigin();
if (props.isAllowLocalOriginHeader() && getOriginHeader() != null) {
origin = getOriginHeader().toLowerCase();
} else if (!isBlank(principal) && client.isAllowUsernameClaimOrigin() && principal.contains("@")) {
Expand All @@ -101,7 +100,7 @@ String getUsername(Claims claims) {
if (principal.contains("@")) {
principal = principal.substring(0, principal.indexOf("@"));
}
var authorities = getPartialAuthorities(claims);
var authorities = getPartialAuthorities(claims, origin);
if (isBlank(principal) ||
!principal.matches(Tag.QTAG_REGEX) ||
principal.equals("+user") ||
Expand All @@ -127,11 +126,11 @@ String getUsername(Claims claims) {
}

@Override
public boolean validateToken(String authToken) {
public boolean validateToken(String authToken, String origin) {
if (!StringUtils.hasText(authToken)) return false;
var client = props.getSecurity().getClient(getPartialOrigin());
var client = props.getSecurity().getClient(origin);
try {
var claims = jwtParser.get(getPartialOrigin()).parseClaimsJws(authToken).getBody();
var claims = jwtParser.get(origin).parseClaimsJws(authToken).getBody();
if (!client.getAuthentication().getJwt().getClientId().equals(claims.getAudience())) {
this.securityMetersService.trackTokenInvalidAudience();
logger.trace(INVALID_JWT_TOKEN + " Invalid Audience");
Expand Down
19 changes: 5 additions & 14 deletions src/main/java/jasper/security/jwt/AbstractTokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import static jasper.security.Auth.USER_ROLE_HEADER;
import static jasper.security.Auth.getHeader;
import static jasper.security.Auth.getOriginHeader;
import static org.apache.commons.lang3.StringUtils.isNotBlank;

public abstract class AbstractTokenProvider implements TokenProvider {
Expand All @@ -35,8 +34,8 @@ User getUser(String userTag) {
return userDetailsProvider.findOneByQualifiedTag(userTag).orElse(null);
}

Collection<? extends GrantedAuthority> getAuthorities(User user) {
var auth = getPartialAuthorities();
Collection<? extends GrantedAuthority> getAuthorities(User user, String origin) {
var auth = getPartialAuthorities(origin);
if (user != null && user.getRole() != null) {
logger.debug("User Roles: {}", user.getRole());
if (User.ROLES.contains(user.getRole().trim())) {
Expand All @@ -48,10 +47,10 @@ Collection<? extends GrantedAuthority> getAuthorities(User user) {
return auth;
}

List<SimpleGrantedAuthority> getPartialAuthorities() {
var client = props.getSecurity().getClient(getPartialOrigin());
List<SimpleGrantedAuthority> getPartialAuthorities(String origin) {
var client = props.getSecurity().getClient(origin);
var authString = client == null ? "ROLE_ANONYMOUS" : client.getDefaultRole();
if (props.getSecurity().getClient(getPartialOrigin()).isAllowUserRoleHeader() && isNotBlank(getHeader(USER_ROLE_HEADER))) {
if (props.getSecurity().getClient(origin).isAllowUserRoleHeader() && isNotBlank(getHeader(USER_ROLE_HEADER))) {
logger.debug("Header Roles: {}", getHeader(USER_ROLE_HEADER));
authString += ", " + getHeader(USER_ROLE_HEADER);
}
Expand All @@ -62,12 +61,4 @@ List<SimpleGrantedAuthority> getPartialAuthorities() {
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toList());
}

public String getPartialOrigin() {
var origin = props.getLocalOrigin();
if (props.isAllowLocalOriginHeader() && getOriginHeader() != null) {
origin = getOriginHeader().toLowerCase();
}
return origin;
}
}
6 changes: 4 additions & 2 deletions src/main/java/jasper/security/jwt/JWTConfigurer.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@

public class JWTConfigurer extends SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity> {

private final Props props;
private final TokenProvider tokenProvider;
private final TokenProviderImplDefault defaultTokenProvider;

public JWTConfigurer(TokenProvider tokenProvider, TokenProviderImplDefault defaultTokenProvider) {
public JWTConfigurer(Props props, TokenProvider tokenProvider, TokenProviderImplDefault defaultTokenProvider) {
this.props = props;
this.tokenProvider = tokenProvider;
this.defaultTokenProvider = defaultTokenProvider;
}

@Override
public void configure(HttpSecurity http) {
http.addFilterBefore(new JWTFilter(tokenProvider, defaultTokenProvider), UsernamePasswordAuthenticationFilter.class);
http.addFilterBefore(new JWTFilter(props, tokenProvider, defaultTokenProvider), UsernamePasswordAuthenticationFilter.class);
}
}
33 changes: 23 additions & 10 deletions src/main/java/jasper/security/jwt/JWTFilter.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package jasper.security.jwt;

import jasper.config.Props;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.GenericFilterBean;
Expand All @@ -14,6 +14,8 @@
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

import static jasper.security.Auth.LOCAL_ORIGIN_HEADER;

/**
* Filters incoming requests and installs a Spring Security principal if a header corresponding to a valid user is
* found.
Expand All @@ -23,32 +25,43 @@ public class JWTFilter extends GenericFilterBean {

public static final String AUTHORIZATION_HEADER = "Authorization";

private final Props props;
private final TokenProvider tokenProvider;
private final TokenProviderImplDefault defaultTokenProvider;

public JWTFilter(TokenProvider tokenProvider, TokenProviderImplDefault defaultTokenProvider) {
public JWTFilter(Props props, TokenProvider tokenProvider, TokenProviderImplDefault defaultTokenProvider) {
this.props = props;
this.tokenProvider = tokenProvider;
this.defaultTokenProvider = defaultTokenProvider;
}

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
String jwt = resolveToken(httpServletRequest);
if (tokenProvider.validateToken(jwt)) {
SecurityContextHolder.getContext().setAuthentication(tokenProvider.getAuthentication(jwt));
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
var httpServletRequest = (HttpServletRequest) servletRequest;
var origin = resolveOrigin(httpServletRequest);
var jwt = resolveToken(httpServletRequest);
if (tokenProvider.validateToken(jwt, origin)) {
SecurityContextHolder.getContext().setAuthentication(tokenProvider.getAuthentication(jwt, origin));
} else {
SecurityContextHolder.getContext().setAuthentication(defaultTokenProvider.getAuthentication(null));
SecurityContextHolder.getContext().setAuthentication(defaultTokenProvider.getAuthentication(null, origin));
}
filterChain.doFilter(servletRequest, servletResponse);
}

private String resolveToken(HttpServletRequest request) {
String bearerToken = request.getHeader(AUTHORIZATION_HEADER);
var bearerToken = request.getHeader(AUTHORIZATION_HEADER);
if (StringUtils.hasText(bearerToken) && bearerToken.startsWith("Bearer ")) {
return bearerToken.substring(7);
}
return null;
}

private String resolveOrigin(HttpServletRequest request) {
var origin = props.getLocalOrigin();
var headerOrigin = request.getHeader(LOCAL_ORIGIN_HEADER);
if (props.isAllowLocalOriginHeader() && headerOrigin != null) {
return headerOrigin.toLowerCase();
}
return origin;
}
}
5 changes: 2 additions & 3 deletions src/main/java/jasper/security/jwt/TokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import org.springframework.security.core.Authentication;

public interface TokenProvider {
boolean validateToken(String jwt);
Authentication getAuthentication(String jwt);
String getPartialOrigin();
boolean validateToken(String jwt, String origin);
Authentication getAuthentication(String jwt, String origin);
}
Loading

0 comments on commit f4fbb24

Please sign in to comment.