Skip to content

Commit

Permalink
Merge pull request #39 from cesarhernandezgt/tomcat-10.0.x-TT.x-patch…
Browse files Browse the repository at this point in the history
…-wsocket

backport websocket fixes
  • Loading branch information
cesarhernandezgt authored Apr 3, 2024
2 parents d1cd752 + 80e07a7 commit c32e9b9
Show file tree
Hide file tree
Showing 11 changed files with 497 additions and 303 deletions.
6 changes: 6 additions & 0 deletions java/org/apache/tomcat/websocket/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

import jakarta.websocket.Extension;

Expand Down Expand Up @@ -118,6 +119,11 @@ public class Constants {
// Milliseconds so this is 20 seconds
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;

// Configuration for session close timeout
public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
// Default is 30 seconds - setting is in milliseconds
public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);

// Configuration for read idle timeout on WebSocket session
public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";

Expand Down
485 changes: 219 additions & 266 deletions java/org/apache/tomcat/websocket/WsSession.java

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions java/org/apache/tomcat/websocket/WsWebSocketContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,12 @@ Set<Session> getOpenSessions(Object key) {
synchronized (endPointSessionMapLock) {
Set<WsSession> sessions = endpointSessionMap.get(key);
if (sessions != null) {
result.addAll(sessions);
// Some sessions may be in the process of closing
for (WsSession session : sessions) {
if (session.isOpen()) {
result.add(session);
}
}
}
}
return result;
Expand Down Expand Up @@ -1108,12 +1113,14 @@ private AsynchronousChannelGroup getAsynchronousChannelGroup() {
@Override
public void backgroundProcess() {
// This method gets called once a second.
backgroundProcessCount ++;
backgroundProcessCount++;
if (backgroundProcessCount >= processPeriod) {
backgroundProcessCount = 0;

// Check all registered sessions.
for (WsSession wsSession : sessions.keySet()) {
wsSession.checkExpiration();
wsSession.checkCloseTimeout();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,8 @@ protected void registerSession(Object key, WsSession wsSession) {
*/
@Override
protected void unregisterSession(Object key, WsSession wsSession) {
if (wsSession.getUserPrincipal() != null &&
wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession,
wsSession.getHttpSessionId());
if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId());
}
super.unregisterSession(key, wsSession);
}
Expand Down
10 changes: 10 additions & 0 deletions test/org/apache/catalina/startup/TomcatBaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ public Tomcat getTomcatInstanceTestWebapp(boolean addJstl, boolean start)
return tomcat;
}


public Context getProgrammaticRootContext() {
// No file system docBase required
Context ctx = tomcat.addContext("", null);
// Disable class path scanning - it slows the tests down by almost an order of magnitude
((StandardJarScanner) ctx.getJarScanner()).setScanClassPath(false);
return ctx;
}


/*
* Sub-classes need to know port so they can connect
*/
Expand Down
132 changes: 120 additions & 12 deletions test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.servlet.ServletContextEvent;
import jakarta.servlet.ServletContextListener;
import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.CloseReason;
import jakarta.websocket.ContainerProvider;
Expand All @@ -39,16 +41,21 @@
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
import org.apache.tomcat.websocket.server.WsServerContainer;

public class TestWsSessionSuspendResume extends WebSocketBaseTest {

@Test
public void test() throws Exception {
public void testSuspendResume() throws Exception {
//public void test() throws Exception {
Tomcat tomcat = getTomcatInstance();

Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Config.class.getName());
//Context ctx = tomcat.addContext("", null);
Context ctx = getProgrammaticRootContext();
//ctx.addApplicationListener(Config.class.getName());
ctx.addApplicationListener(SuspendResumeConfig.class.getName());

Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
Expand All @@ -58,10 +65,12 @@ public void test() throws Exception {
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("ws://localhost:" + getPort() + Config.PATH));
// Session wsSession = wsContainer.connectToServer(
// TesterProgrammaticEndpoint.class,
// clientEndpointConfig,
// new URI("ws://localhost:" + getPort() + Config.PATH));
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));

CountDownLatch latch = new CountDownLatch(2);
wsSession.addMessageHandler(String.class, message -> {
Expand All @@ -79,7 +88,8 @@ public void test() throws Exception {
}


public static final class Config extends TesterEndpointConfig {
//public static final class Config extends TesterEndpointConfig {
public static final class SuspendResumeConfig extends TesterEndpointConfig {
private static final String PATH = "/echo";

@Override
Expand All @@ -97,8 +107,9 @@ protected ServerEndpointConfig getServerEndpointConfig() {
public static final class SuspendResumeEndpoint extends Endpoint {

@Override
public void onOpen(Session session, EndpointConfig epc) {
MessageProcessor processor = new MessageProcessor(session, 3);
public void onOpen(Session session, EndpointConfig epc) {
//MessageProcessor processor = new MessageProcessor(session, 3);
SuspendResumeMessageProcessor processor = new SuspendResumeMessageProcessor(session, 3);
session.addMessageHandler(String.class, message -> processor.addMessage(message));
}

Expand All @@ -118,12 +129,14 @@ public void onError(Session session, Throwable t) {
}


private static final class MessageProcessor {
//private static final class MessageProcessor {
private static final class SuspendResumeMessageProcessor {
private final Session session;
private final int count;
private final List<String> messages = new ArrayList<>();

MessageProcessor(Session session, int count) {
//MessageProcessor(Session session, int count) {
SuspendResumeMessageProcessor(Session session, int count) {
this.session = session;
this.count = count;
}
Expand All @@ -143,4 +156,99 @@ void addMessage(String message) {
}
}
}


@Test
public void testSuspendThenClose() throws Exception {
Tomcat tomcat = getTomcatInstance();

Context ctx = getProgrammaticRootContext();
ctx.addApplicationListener(SuspendCloseConfig.class.getName());
ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());

Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");

tomcat.start();

WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));

wsSession.getBasicRemote().sendText("start test");

// Wait for the client response to be received by the server
int count = 0;
while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
Thread.sleep(100);
count ++;
}
Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
}


public static final class SuspendCloseConfig extends TesterEndpointConfig {
private static final String PATH = "/echo";

@Override
protected Class<?> getEndpointClass() {
return SuspendCloseEndpoint.class;
}

@Override
protected ServerEndpointConfig getServerEndpointConfig() {
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
}
}


public static final class SuspendCloseEndpoint extends Endpoint {

// Yes, a static variable is a hack.
private static WsSession serverSession;

@Override
public void onOpen(Session session, EndpointConfig epc) {
serverSession = (WsSession) session;
// Set a short session close timeout (milliseconds)
serverSession.getUserProperties().put(
org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
// Any message will trigger the suspend then close
serverSession.addMessageHandler(String.class, message -> {
try {
serverSession.getBasicRemote().sendText("server session open");
serverSession.getBasicRemote().sendText("suspending server session");
serverSession.suspend();
serverSession.getBasicRemote().sendText("closing server session");
serverSession.close();
} catch (IOException ioe) {
ioe.printStackTrace();
// Attempt to make the failure more obvious
throw new RuntimeException(ioe);
}
});
}

@Override
public void onError(Session session, Throwable t) {
t.printStackTrace();
}

public static boolean isServerSessionFullyClosed() {
return serverSession.isClosed();
}
}


public static class WebSocketFastServerTimeout implements ServletContextListener {

@Override
public void contextInitialized(ServletContextEvent sce) {
WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
container.setProcessPeriod(0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,45 +35,41 @@
* significantly extends the length of a test run when using multiple test
* threads.
*/
public class TestWsWebSocketContainerSessionExpiryContainer extends WsWebSocketContainerBaseTest {
public class TestWsWebSocketContainerSessionExpiryContainerClient extends WsWebSocketContainerBaseTest {

@Test
public void testSessionExpiryContainer() throws Exception {

Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
Context ctx = getProgrammaticRootContext();
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");

tomcat.start();

// Need access to implementation methods for configuring unit tests
WsWebSocketContainer wsContainer = (WsWebSocketContainer)
ContainerProvider.getWebSocketContainer();
WsWebSocketContainer wsContainer = (WsWebSocketContainer) ContainerProvider.getWebSocketContainer();

// 5 second timeout
wsContainer.setDefaultMaxSessionIdleTimeout(5000);
wsContainer.setProcessPeriod(1);

EndpointA endpointA = new EndpointA();
connectToEchoServer(wsContainer, endpointA,
TesterEchoServer.Config.PATH_BASIC);
connectToEchoServer(wsContainer, endpointA,
TesterEchoServer.Config.PATH_BASIC);
Session s3a = connectToEchoServer(wsContainer, endpointA,
TesterEchoServer.Config.PATH_BASIC);
connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC);
connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC);
Session s3a = connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC);

// Check all three sessions are open
Set<Session> setA = s3a.getOpenSessions();
Assert.assertEquals(3, setA.size());

int count = 0;
boolean isOpen = true;
while (isOpen && count < 8) {
count ++;
Thread.sleep(1000);
while (isOpen && count < 100) {
count++;
Thread.sleep(100);
isOpen = false;
for (Session session : setA) {
if (session.isOpen()) {
Expand All @@ -86,8 +82,7 @@ public void testSessionExpiryContainer() throws Exception {
if (isOpen) {
for (Session session : setA) {
if (session.isOpen()) {
System.err.println("Session with ID [" + session.getId() +
"] is open");
System.err.println("Session with ID [" + session.getId() + "] is open");
}
}
Assert.fail("There were open sessions");
Expand Down
Loading

0 comments on commit c32e9b9

Please sign in to comment.