Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backport websocket fixes #39

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions java/org/apache/tomcat/websocket/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

import jakarta.websocket.Extension;

Expand Down Expand Up @@ -118,6 +119,11 @@ public class Constants {
// Milliseconds so this is 20 seconds
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;

// Configuration for session close timeout
public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
// Default is 30 seconds - setting is in milliseconds
public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);

// Configuration for read idle timeout on WebSocket session
public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";

Expand Down
485 changes: 219 additions & 266 deletions java/org/apache/tomcat/websocket/WsSession.java

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions java/org/apache/tomcat/websocket/WsWebSocketContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,12 @@ Set<Session> getOpenSessions(Object key) {
synchronized (endPointSessionMapLock) {
Set<WsSession> sessions = endpointSessionMap.get(key);
if (sessions != null) {
result.addAll(sessions);
// Some sessions may be in the process of closing
for (WsSession session : sessions) {
if (session.isOpen()) {
result.add(session);
}
}
}
}
return result;
Expand Down Expand Up @@ -1108,12 +1113,14 @@ private AsynchronousChannelGroup getAsynchronousChannelGroup() {
@Override
public void backgroundProcess() {
// This method gets called once a second.
backgroundProcessCount ++;
backgroundProcessCount++;
if (backgroundProcessCount >= processPeriod) {
backgroundProcessCount = 0;

// Check all registered sessions.
for (WsSession wsSession : sessions.keySet()) {
wsSession.checkExpiration();
wsSession.checkCloseTimeout();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,8 @@ protected void registerSession(Object key, WsSession wsSession) {
*/
@Override
protected void unregisterSession(Object key, WsSession wsSession) {
if (wsSession.getUserPrincipal() != null &&
wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession,
wsSession.getHttpSessionId());
if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) {
unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId());
}
super.unregisterSession(key, wsSession);
}
Expand Down
10 changes: 10 additions & 0 deletions test/org/apache/catalina/startup/TomcatBaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ public Tomcat getTomcatInstanceTestWebapp(boolean addJstl, boolean start)
return tomcat;
}


public Context getProgrammaticRootContext() {
// No file system docBase required
Context ctx = tomcat.addContext("", null);
// Disable class path scanning - it slows the tests down by almost an order of magnitude
((StandardJarScanner) ctx.getJarScanner()).setScanClassPath(false);
return ctx;
}


/*
* Sub-classes need to know port so they can connect
*/
Expand Down
132 changes: 120 additions & 12 deletions test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.servlet.ServletContextEvent;
import jakarta.servlet.ServletContextListener;
import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.CloseReason;
import jakarta.websocket.ContainerProvider;
Expand All @@ -39,16 +41,21 @@
import org.apache.catalina.servlets.DefaultServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
import org.apache.tomcat.websocket.server.Constants;
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
import org.apache.tomcat.websocket.server.WsServerContainer;

public class TestWsSessionSuspendResume extends WebSocketBaseTest {

@Test
public void test() throws Exception {
public void testSuspendResume() throws Exception {
//public void test() throws Exception {
Tomcat tomcat = getTomcatInstance();

Context ctx = tomcat.addContext("", null);
ctx.addApplicationListener(Config.class.getName());
//Context ctx = tomcat.addContext("", null);
Context ctx = getProgrammaticRootContext();
//ctx.addApplicationListener(Config.class.getName());
ctx.addApplicationListener(SuspendResumeConfig.class.getName());

Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");
Expand All @@ -58,10 +65,12 @@ public void test() throws Exception {
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(
TesterProgrammaticEndpoint.class,
clientEndpointConfig,
new URI("ws://localhost:" + getPort() + Config.PATH));
// Session wsSession = wsContainer.connectToServer(
// TesterProgrammaticEndpoint.class,
// clientEndpointConfig,
// new URI("ws://localhost:" + getPort() + Config.PATH));
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));

CountDownLatch latch = new CountDownLatch(2);
wsSession.addMessageHandler(String.class, message -> {
Expand All @@ -79,7 +88,8 @@ public void test() throws Exception {
}


public static final class Config extends TesterEndpointConfig {
//public static final class Config extends TesterEndpointConfig {
public static final class SuspendResumeConfig extends TesterEndpointConfig {
private static final String PATH = "/echo";

@Override
Expand All @@ -97,8 +107,9 @@ protected ServerEndpointConfig getServerEndpointConfig() {
public static final class SuspendResumeEndpoint extends Endpoint {

@Override
public void onOpen(Session session, EndpointConfig epc) {
MessageProcessor processor = new MessageProcessor(session, 3);
public void onOpen(Session session, EndpointConfig epc) {
//MessageProcessor processor = new MessageProcessor(session, 3);
SuspendResumeMessageProcessor processor = new SuspendResumeMessageProcessor(session, 3);
session.addMessageHandler(String.class, message -> processor.addMessage(message));
}

Expand All @@ -118,12 +129,14 @@ public void onError(Session session, Throwable t) {
}


private static final class MessageProcessor {
//private static final class MessageProcessor {
private static final class SuspendResumeMessageProcessor {
private final Session session;
private final int count;
private final List<String> messages = new ArrayList<>();

MessageProcessor(Session session, int count) {
//MessageProcessor(Session session, int count) {
SuspendResumeMessageProcessor(Session session, int count) {
this.session = session;
this.count = count;
}
Expand All @@ -143,4 +156,99 @@ void addMessage(String message) {
}
}
}


@Test
public void testSuspendThenClose() throws Exception {
Tomcat tomcat = getTomcatInstance();

Context ctx = getProgrammaticRootContext();
ctx.addApplicationListener(SuspendCloseConfig.class.getName());
ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());

Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");

tomcat.start();

WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();

ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));

wsSession.getBasicRemote().sendText("start test");

// Wait for the client response to be received by the server
int count = 0;
while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
Thread.sleep(100);
count ++;
}
Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
}


public static final class SuspendCloseConfig extends TesterEndpointConfig {
private static final String PATH = "/echo";

@Override
protected Class<?> getEndpointClass() {
return SuspendCloseEndpoint.class;
}

@Override
protected ServerEndpointConfig getServerEndpointConfig() {
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
}
}


public static final class SuspendCloseEndpoint extends Endpoint {

// Yes, a static variable is a hack.
private static WsSession serverSession;

@Override
public void onOpen(Session session, EndpointConfig epc) {
serverSession = (WsSession) session;
// Set a short session close timeout (milliseconds)
serverSession.getUserProperties().put(
org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
// Any message will trigger the suspend then close
serverSession.addMessageHandler(String.class, message -> {
try {
serverSession.getBasicRemote().sendText("server session open");
serverSession.getBasicRemote().sendText("suspending server session");
serverSession.suspend();
serverSession.getBasicRemote().sendText("closing server session");
serverSession.close();
} catch (IOException ioe) {
ioe.printStackTrace();
// Attempt to make the failure more obvious
throw new RuntimeException(ioe);
}
});
}

@Override
public void onError(Session session, Throwable t) {
t.printStackTrace();
}

public static boolean isServerSessionFullyClosed() {
return serverSession.isClosed();
}
}


public static class WebSocketFastServerTimeout implements ServletContextListener {

@Override
public void contextInitialized(ServletContextEvent sce) {
WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
container.setProcessPeriod(0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,45 +35,41 @@
* significantly extends the length of a test run when using multiple test
* threads.
*/
public class TestWsWebSocketContainerSessionExpiryContainer extends WsWebSocketContainerBaseTest {
public class TestWsWebSocketContainerSessionExpiryContainerClient extends WsWebSocketContainerBaseTest {

@Test
public void testSessionExpiryContainer() throws Exception {

Tomcat tomcat = getTomcatInstance();
// No file system docBase required
Context ctx = tomcat.addContext("", null);
Context ctx = getProgrammaticRootContext();
ctx.addApplicationListener(TesterEchoServer.Config.class.getName());
Tomcat.addServlet(ctx, "default", new DefaultServlet());
ctx.addServletMappingDecoded("/", "default");

tomcat.start();

// Need access to implementation methods for configuring unit tests
WsWebSocketContainer wsContainer = (WsWebSocketContainer)
ContainerProvider.getWebSocketContainer();
WsWebSocketContainer wsContainer = (WsWebSocketContainer) ContainerProvider.getWebSocketContainer();

// 5 second timeout
wsContainer.setDefaultMaxSessionIdleTimeout(5000);
wsContainer.setProcessPeriod(1);

EndpointA endpointA = new EndpointA();
connectToEchoServer(wsContainer, endpointA,
TesterEchoServer.Config.PATH_BASIC);
connectToEchoServer(wsContainer, endpointA,
TesterEchoServer.Config.PATH_BASIC);
Session s3a = connectToEchoServer(wsContainer, endpointA,
TesterEchoServer.Config.PATH_BASIC);
connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC);
connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC);
Session s3a = connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC);

// Check all three sessions are open
Set<Session> setA = s3a.getOpenSessions();
Assert.assertEquals(3, setA.size());

int count = 0;
boolean isOpen = true;
while (isOpen && count < 8) {
count ++;
Thread.sleep(1000);
while (isOpen && count < 100) {
count++;
Thread.sleep(100);
isOpen = false;
for (Session session : setA) {
if (session.isOpen()) {
Expand All @@ -86,8 +82,7 @@ public void testSessionExpiryContainer() throws Exception {
if (isOpen) {
for (Session session : setA) {
if (session.isOpen()) {
System.err.println("Session with ID [" + session.getId() +
"] is open");
System.err.println("Session with ID [" + session.getId() + "] is open");
}
}
Assert.fail("There were open sessions");
Expand Down
Loading
Loading