Skip to content

Commit

Permalink
XWIKI-22716: The WebSocket context assumes the handshake request and …
Browse files Browse the repository at this point in the history
…response objects can be used after the handshake is performed
  • Loading branch information
mflorea committed Dec 5, 2024
1 parent 69fe14a commit 924e3c6
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.xwiki.websocket.internal;

import java.net.HttpCookie;
import java.net.URI;
import java.security.Principal;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
Expand All @@ -32,6 +33,8 @@
import java.util.Optional;
import java.util.TreeMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xwiki.container.servlet.HttpServletRequestStub;

import jakarta.servlet.http.Cookie;
Expand All @@ -46,8 +49,16 @@
*/
public class XWikiWebSocketRequestStub extends HttpServletRequestStub
{
private static final Logger LOGGER = LoggerFactory.getLogger(XWikiWebSocketRequestStub.class);

private final HandshakeRequest request;

private final URI requestURI;

private final String queryString;

private final Principal userPrincipal;

/**
* Creates a new XWiki request that wraps the given WebSocket handshake request.
*
Expand All @@ -58,6 +69,9 @@ public XWikiWebSocketRequestStub(HandshakeRequest request)
super(buildFromHandshakeRequest(request));

this.request = request;
this.requestURI = request.getRequestURI();
this.queryString = request.getQueryString();
this.userPrincipal = request.getUserPrincipal();
}

private static Builder buildFromHandshakeRequest(HandshakeRequest request)
Expand All @@ -66,6 +80,7 @@ private static Builder buildFromHandshakeRequest(HandshakeRequest request)
Optional<String> cookieHeader = headers.getOrDefault("Cookie", Collections.emptyList()).stream().findFirst();
return new Builder().setRequestParameters(adaptParameterMap(request.getParameterMap()))
.setCookies(parseCookies(cookieHeader)).setHeaders(headers)
.setHttpSession((HttpSession) request.getHttpSession())
// The WebSocket API (JSR-356) doesn't expose the client IP address but at least we can avoid a null pointer
// exception.
.setRemoteAddr("");
Expand Down Expand Up @@ -110,7 +125,7 @@ public String getMethod()
@Override
public String getRequestURI()
{
return this.request.getRequestURI().toString();
return this.requestURI.toString();
}

private static Map<String, String[]> adaptParameterMap(Map<String, List<String>> params)
Expand All @@ -122,18 +137,6 @@ private static Map<String, String[]> adaptParameterMap(Map<String, List<String>>
return parameters;
}

@Override
public HttpSession getSession()
{
return getSession(true);
}

@Override
public HttpSession getSession(boolean create)
{
return (HttpSession) this.request.getHttpSession();
}

@Override
public String getServletPath()
{
Expand All @@ -143,30 +146,37 @@ public String getServletPath()
@Override
public String getPathInfo()
{
return this.request.getRequestURI().getPath();
return this.requestURI.getPath();
}

@Override
public String getScheme()
{
return this.request.getRequestURI().getScheme();
return this.requestURI.getScheme();
}

@Override
public String getQueryString()
{
return this.request.getQueryString();
return this.queryString;
}

@Override
public Principal getUserPrincipal()
{
return this.request.getUserPrincipal();
return this.userPrincipal;
}

@Override
public boolean isUserInRole(String role)
{
return this.request.isUserInRole(role);
try {
return this.request.isUserInRole(role);
} catch (Exception e) {
LOGGER.debug("Failed to determine if the currently authenticated user has the specified role. "
+ "This can happen if this method is called outside the WebSocket handshake request, "
+ "i.e. from a WebSocket end-point.", e);
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,20 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import jakarta.servlet.http.Cookie;
import jakarta.websocket.HandshakeResponse;

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xwiki.container.servlet.HttpServletResponseStub;

/**
Expand All @@ -46,6 +49,8 @@
*/
public class XWikiWebSocketResponseStub extends HttpServletResponseStub
{
private static final Logger LOGGER = LoggerFactory.getLogger(XWikiWebSocketResponseStub.class);

private final HandshakeResponse response;

/**
Expand All @@ -61,58 +66,67 @@ public XWikiWebSocketResponseStub(HandshakeResponse response)
@Override
public void addHeader(String name, String value)
{
List<String> values = getHeaderValues(name);
if (values == null) {
values = new ArrayList<>();
this.response.getHeaders().put(name, values);
}
List<String> values = getHeaderValues(name).orElseGet(() -> {
List<String> emptyValues = new ArrayList<>();
getHeaders().put(name, emptyValues);
return emptyValues;
});
values.add(value);
}

@Override
public boolean containsHeader(String name)
{
List<String> values = getHeaderValues(name);
return values != null && !values.isEmpty();
return getHeaderValues(name).map(values -> !values.isEmpty()).orElse(false);
}

@Override
public String getHeader(String name)
{
List<String> values = getHeaderValues(name);
return values != null && !values.isEmpty() ? values.get(0) : null;
return getHeaderValues(name).map(values -> values.isEmpty() ? null : values.get(0)).orElse(null);
}

private List<String> getHeaderValues(String name)
private Optional<List<String>> getHeaderValues(String name)
{
for (Map.Entry<String, List<String>> entry : this.response.getHeaders().entrySet()) {
for (Map.Entry<String, List<String>> entry : getHeaders().entrySet()) {
if (StringUtils.equalsIgnoreCase(name, entry.getKey())) {
return entry.getValue();
return Optional.of(entry.getValue());
}
}
return null;
return Optional.empty();
}

@Override
public Collection<String> getHeaders(String name)
{
List<String> values = getHeaderValues(name);
return values != null ? new ArrayList<>(values) : Collections.emptyList();
return getHeaderValues(name).map(ArrayList::new).orElseGet(ArrayList::new);
}

@Override
public Collection<String> getHeaderNames()
{
return new LinkedHashSet<>(this.response.getHeaders().keySet());
return new LinkedHashSet<>(getHeaders().keySet());
}

private Map<String, List<String>> getHeaders()
{
try {
return this.response.getHeaders();
} catch (Exception e) {
LOGGER.debug("Failed to retrieve the WebSocket handshake response headers. "
+ "This can happen if the HandshakeResponse object is used after the handshake is performed, "
+ "e.g. in the WebSocket end-point.", e);
return new HashMap<>();
}
}

@Override
public void setHeader(String name, String value)
{
Set<String> namesToRemove = this.response.getHeaders().keySet().stream()
Set<String> namesToRemove = getHeaders().keySet().stream()
.filter(headerName -> StringUtils.equalsIgnoreCase(name, headerName)).collect(Collectors.toSet());
this.response.getHeaders().keySet().removeAll(namesToRemove);
this.response.getHeaders().put(name, new ArrayList<>(Arrays.asList(value)));
getHeaders().keySet().removeAll(namesToRemove);
getHeaders().put(name, new ArrayList<>(Arrays.asList(value)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -105,4 +107,15 @@ void verifyStub() throws Exception
assertFalse(stub.isUserInRole("developer"));
assertTrue(stub.isUserInRole("tester"));
}

@Test
void staleRequest()
{
HandshakeRequest handshakeRequest = mock(HandshakeRequest.class);
when(handshakeRequest.isUserInRole(anyString())).thenThrow(new RuntimeException("Stale request"));

XWikiWebSocketRequestStub stub = new XWikiWebSocketRequestStub(handshakeRequest);

assertFalse(stub.isUserInRole("admin"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -45,7 +46,7 @@
class XWikiWebSocketResponseStubTest
{
@Test
void verifyStub() throws Exception
void verifyStub()
{
Map<String, List<String>> headers = new LinkedHashMap<>();

Expand Down Expand Up @@ -86,4 +87,15 @@ void verifyStub() throws Exception
assertTrue(stub.containsHeader("dATe"));
assertFalse(stub.containsHeader("Age"));
}

@Test
void staleResponse()
{
HandshakeResponse handshakeResponse = mock(HandshakeResponse.class);
when(handshakeResponse.getHeaders()).thenThrow(new RuntimeException("Stale response"));

XWikiWebSocketResponseStub stub = new XWikiWebSocketResponseStub(handshakeResponse);

assertNull(stub.getHeader("foo"));
}
}

0 comments on commit 924e3c6

Please sign in to comment.