diff --git a/java/org/apache/tomcat/websocket/Constants.java b/java/org/apache/tomcat/websocket/Constants.java index 8dffa303996e..3bb4e0057899 100644 --- a/java/org/apache/tomcat/websocket/Constants.java +++ b/java/org/apache/tomcat/websocket/Constants.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import jakarta.websocket.Extension; @@ -118,6 +119,11 @@ public class Constants { // Milliseconds so this is 20 seconds public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + // Configuration for session close timeout + public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT"; + // Default is 30 seconds - setting is in milliseconds + public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + // Configuration for read idle timeout on WebSocket session public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS"; diff --git a/java/org/apache/tomcat/websocket/WsSession.java b/java/org/apache/tomcat/websocket/WsSession.java index 8766d46b1c1d..0c1b2d18dd92 100644 --- a/java/org/apache/tomcat/websocket/WsSession.java +++ b/java/org/apache/tomcat/websocket/WsSession.java @@ -27,7 +27,9 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import javax.naming.NamingException; @@ -78,8 +80,8 @@ public class WsSession implements Session { // be sufficient to pass the validation tests. ServerEndpointConfig.Builder builder = ServerEndpointConfig.Builder.create(Object.class, "/"); ServerEndpointConfig sec = builder.build(); - SEC_CONFIGURATOR_USES_IMPL_DEFAULT = - sec.getConfigurator().getClass().equals(DefaultServerEndpointConfigurator.class); + SEC_CONFIGURATOR_USES_IMPL_DEFAULT = sec.getConfigurator().getClass() + .equals(DefaultServerEndpointConfigurator.class); } private final Endpoint localEndpoint; @@ -106,8 +108,7 @@ public class WsSession implements Session { // Expected to handle message types of only private volatile MessageHandler binaryMessageHandler = null; private volatile MessageHandler.Whole pongMessageHandler = null; - private volatile State state = State.OPEN; - private final Object stateLock = new Object(); + private AtomicReference state = new AtomicReference<>(State.OPEN); private final Map userProperties = new ConcurrentHashMap<>(); private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; @@ -115,35 +116,30 @@ public class WsSession implements Session { private volatile long lastActiveRead = System.currentTimeMillis(); private volatile long lastActiveWrite = System.currentTimeMillis(); private Map futures = new ConcurrentHashMap<>(); + private volatile Long sessionCloseTimeoutExpiry; /** - * Creates a new WebSocket session for communication between the provided - * client and remote end points. The result of - * {@link Thread#getContextClassLoader()} at the time this constructor is - * called will be used when calling + * Creates a new WebSocket session for communication between the provided client and remote end points. The result + * of {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling * {@link Endpoint#onClose(Session, CloseReason)}. * * @param clientEndpointHolder The end point managed by this code * @param wsRemoteEndpoint The other / remote end point * @param wsWebSocketContainer The container that created this session * @param negotiatedExtensions The agreed extensions to use for this session - * @param subProtocol The agreed sub-protocol to use for this - * session - * @param pathParameters The path parameters associated with the - * request that initiated this session or - * null if this is a client session - * @param secure Was this session initiated over a secure - * connection? - * @param clientEndpointConfig The configuration information for the client - * end point + * @param subProtocol The agreed sub-protocol to use for this session + * @param pathParameters The path parameters associated with the request that initiated this session or + * null if this is a client session + * @param secure Was this session initiated over a secure connection? + * @param clientEndpointConfig The configuration information for the client end point + * * @throws DeploymentException if an invalid encode is specified */ - public WsSession(ClientEndpointHolder clientEndpointHolder, - WsRemoteEndpointImplBase wsRemoteEndpoint, - WsWebSocketContainer wsWebSocketContainer, - List negotiatedExtensions, String subProtocol, Map pathParameters, - boolean secure, ClientEndpointConfig clientEndpointConfig) throws DeploymentException { + public WsSession(ClientEndpointHolder clientEndpointHolder, WsRemoteEndpointImplBase wsRemoteEndpoint, + WsWebSocketContainer wsWebSocketContainer, List negotiatedExtensions, String subProtocol, + Map pathParameters, boolean secure, ClientEndpointConfig clientEndpointConfig) + throws DeploymentException { this.wsRemoteEndpoint = wsRemoteEndpoint; this.wsRemoteEndpoint.setSession(this); this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint); @@ -175,53 +171,43 @@ public WsSession(ClientEndpointHolder clientEndpointHolder, this.localEndpoint = clientEndpointHolder.getInstance(getInstanceManager()); - if (log.isDebugEnabled()) { - log.debug(sm.getString("wsSession.created", id)); + if (log.isTraceEnabled()) { + log.trace(sm.getString("wsSession.created", id)); } } /** - * Creates a new WebSocket session for communication between the provided - * server and remote end points. The result of - * {@link Thread#getContextClassLoader()} at the time this constructor is - * called will be used when calling + * Creates a new WebSocket session for communication between the provided server and remote end points. The result + * of {@link Thread#getContextClassLoader()} at the time this constructor is called will be used when calling * {@link Endpoint#onClose(Session, CloseReason)}. * * @param wsRemoteEndpoint The other / remote end point * @param wsWebSocketContainer The container that created this session - * @param requestUri The URI used to connect to this end point or - * null if this is a client session - * @param requestParameterMap The parameters associated with the request - * that initiated this session or - * null if this is a client session - * @param queryString The query string associated with the request - * that initiated this session or - * null if this is a client session - * @param userPrincipal The principal associated with the request - * that initiated this session or - * null if this is a client session - * @param httpSessionId The HTTP session ID associated with the - * request that initiated this session or - * null if this is a client session + * @param requestUri The URI used to connect to this end point or null if this is a client + * session + * @param requestParameterMap The parameters associated with the request that initiated this session or + * null if this is a client session + * @param queryString The query string associated with the request that initiated this session or + * null if this is a client session + * @param userPrincipal The principal associated with the request that initiated this session or + * null if this is a client session + * @param httpSessionId The HTTP session ID associated with the request that initiated this session or + * null if this is a client session * @param negotiatedExtensions The agreed extensions to use for this session - * @param subProtocol The agreed sub-protocol to use for this - * session - * @param pathParameters The path parameters associated with the - * request that initiated this session or - * null if this is a client session - * @param secure Was this session initiated over a secure - * connection? - * @param serverEndpointConfig The configuration information for the server - * end point + * @param subProtocol The agreed sub-protocol to use for this session + * @param pathParameters The path parameters associated with the request that initiated this session or + * null if this is a client session + * @param secure Was this session initiated over a secure connection? + * @param serverEndpointConfig The configuration information for the server end point + * * @throws DeploymentException if an invalid encode is specified */ - public WsSession(WsRemoteEndpointImplBase wsRemoteEndpoint, - WsWebSocketContainer wsWebSocketContainer, - URI requestUri, Map> requestParameterMap, - String queryString, Principal userPrincipal, String httpSessionId, - List negotiatedExtensions, String subProtocol, Map pathParameters, - boolean secure, ServerEndpointConfig serverEndpointConfig) throws DeploymentException { + public WsSession(WsRemoteEndpointImplBase wsRemoteEndpoint, WsWebSocketContainer wsWebSocketContainer, + URI requestUri, Map> requestParameterMap, String queryString, Principal userPrincipal, + String httpSessionId, List negotiatedExtensions, String subProtocol, + Map pathParameters, boolean secure, ServerEndpointConfig serverEndpointConfig) + throws DeploymentException { this.wsRemoteEndpoint = wsRemoteEndpoint; this.wsRemoteEndpoint.setSession(this); @@ -284,8 +270,8 @@ public WsSession(WsRemoteEndpointImplBase wsRemoteEndpoint, this.localEndpoint = new PojoEndpointServer(pathParameters, endpointInstance); } - if (log.isDebugEnabled()) { - log.debug(sm.getString("wsSession.created", id)); + if (log.isTraceEnabled()) { + log.trace(sm.getString("wsSession.created", id)); } } @@ -302,100 +288,6 @@ private boolean isDefaultConfigurator(Configurator configurator) { } - /** - * Creates a new WebSocket session for communication between the two - * provided end points. The result of {@link Thread#getContextClassLoader()} - * at the time this constructor is called will be used when calling - * {@link Endpoint#onClose(Session, CloseReason)}. - * - * @param localEndpoint The end point managed by this code - * @param wsRemoteEndpoint The other / remote endpoint - * @param wsWebSocketContainer The container that created this session - * @param requestUri The URI used to connect to this endpoint or - * null is this is a client session - * @param requestParameterMap The parameters associated with the request - * that initiated this session or - * null if this is a client session - * @param queryString The query string associated with the request - * that initiated this session or - * null if this is a client session - * @param userPrincipal The principal associated with the request - * that initiated this session or - * null if this is a client session - * @param httpSessionId The HTTP session ID associated with the - * request that initiated this session or - * null if this is a client session - * @param negotiatedExtensions The agreed extensions to use for this session - * @param subProtocol The agreed subprotocol to use for this - * session - * @param pathParameters The path parameters associated with the - * request that initiated this session or - * null if this is a client session - * @param secure Was this session initiated over a secure - * connection? - * @param endpointConfig The configuration information for the - * endpoint - * @throws DeploymentException if an invalid encode is specified - * - * @deprecated Unused. This will be removed in Tomcat 10.1 - */ - @Deprecated - public WsSession(Endpoint localEndpoint, - WsRemoteEndpointImplBase wsRemoteEndpoint, - WsWebSocketContainer wsWebSocketContainer, - URI requestUri, Map> requestParameterMap, - String queryString, Principal userPrincipal, String httpSessionId, - List negotiatedExtensions, String subProtocol, Map pathParameters, - boolean secure, EndpointConfig endpointConfig) throws DeploymentException { - this.localEndpoint = localEndpoint; - this.wsRemoteEndpoint = wsRemoteEndpoint; - this.wsRemoteEndpoint.setSession(this); - this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint); - this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint); - this.webSocketContainer = wsWebSocketContainer; - applicationClassLoader = Thread.currentThread().getContextClassLoader(); - wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout()); - this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize(); - this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize(); - this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout(); - this.requestUri = requestUri; - if (requestParameterMap == null) { - this.requestParameterMap = Collections.emptyMap(); - } else { - this.requestParameterMap = requestParameterMap; - } - this.queryString = queryString; - this.userPrincipal = userPrincipal; - this.httpSessionId = httpSessionId; - this.negotiatedExtensions = negotiatedExtensions; - if (subProtocol == null) { - this.subProtocol = ""; - } else { - this.subProtocol = subProtocol; - } - this.pathParameters = pathParameters; - this.secure = secure; - this.wsRemoteEndpoint.setEncoders(endpointConfig); - this.endpointConfig = endpointConfig; - - this.userProperties.putAll(endpointConfig.getUserProperties()); - this.id = Long.toHexString(ids.getAndIncrement()); - - InstanceManager instanceManager = getInstanceManager(); - if (instanceManager != null) { - try { - instanceManager.newInstance(localEndpoint); - } catch (Exception e) { - throw new DeploymentException(sm.getString("wsSession.instanceNew"), e); - } - } - - if (log.isDebugEnabled()) { - log.debug(sm.getString("wsSession.created", id)); - } - } - - public InstanceManager getInstanceManager() { return webSocketContainer.getInstanceManager(applicationClassLoader); } @@ -416,15 +308,13 @@ public void addMessageHandler(MessageHandler listener) { @Override - public void addMessageHandler(Class clazz, Partial handler) - throws IllegalStateException { + public void addMessageHandler(Class clazz, Partial handler) throws IllegalStateException { doAddMessageHandler(clazz, handler); } @Override - public void addMessageHandler(Class clazz, Whole handler) - throws IllegalStateException { + public void addMessageHandler(Class clazz, Whole handler) throws IllegalStateException { doAddMessageHandler(clazz, handler); } @@ -443,44 +333,41 @@ private void doAddMessageHandler(Class target, MessageHandler listener) { // arbitrary objects with MessageHandlers and can wrap MessageHandlers // just as easily. - Set mhResults = Util.getMessageHandlers(target, listener, - endpointConfig, this); + Set mhResults = Util.getMessageHandlers(target, listener, endpointConfig, this); for (MessageHandlerResult mhResult : mhResults) { switch (mhResult.getType()) { - case TEXT: { - if (textMessageHandler != null) { - throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText")); + case TEXT: { + if (textMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText")); + } + textMessageHandler = mhResult.getHandler(); + break; } - textMessageHandler = mhResult.getHandler(); - break; - } - case BINARY: { - if (binaryMessageHandler != null) { - throw new IllegalStateException( - sm.getString("wsSession.duplicateHandlerBinary")); + case BINARY: { + if (binaryMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerBinary")); + } + binaryMessageHandler = mhResult.getHandler(); + break; } - binaryMessageHandler = mhResult.getHandler(); - break; - } - case PONG: { - if (pongMessageHandler != null) { - throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong")); + case PONG: { + if (pongMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong")); + } + MessageHandler handler = mhResult.getHandler(); + if (handler instanceof MessageHandler.Whole) { + pongMessageHandler = (MessageHandler.Whole) handler; + } else { + throw new IllegalStateException(sm.getString("wsSession.invalidHandlerTypePong")); + } + + break; } - MessageHandler handler = mhResult.getHandler(); - if (handler instanceof MessageHandler.Whole) { - pongMessageHandler = (MessageHandler.Whole) handler; - } else { - throw new IllegalStateException( - sm.getString("wsSession.invalidHandlerTypePong")); + default: { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType())); } - - break; - } - default: { - throw new IllegalArgumentException( - sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType())); - } } } } @@ -539,8 +426,7 @@ public void removeMessageHandler(MessageHandler listener) { if (!removed) { // ISE for now. Could swallow this silently / log this if the ISE // becomes a problem - throw new IllegalStateException( - sm.getString("wsSession.removeHandlerFailed", listener)); + throw new IllegalStateException(sm.getString("wsSession.removeHandlerFailed", listener)); } } @@ -575,7 +461,12 @@ public boolean isSecure() { @Override public boolean isOpen() { - return state == State.OPEN; + return state.get() == State.OPEN || state.get() == State.OUTPUT_CLOSING || state.get() == State.CLOSING; + } + + + public boolean isClosed() { + return state.get() == State.CLOSED; } @@ -655,9 +546,8 @@ public void close(CloseReason closeReason) throws IOException { /** - * WebSocket 1.0. Section 2.1.5. - * Need internal close method as spec requires that the local endpoint - * receives a 1006 on timeout. + * WebSocket 1.0. Section 2.1.5. Need internal close method as spec requires that the local endpoint receives a 1006 + * on timeout. * * @param closeReasonMessage The close reason to pass to the remote endpoint * @param closeReasonLocal The close reason to pass to the local endpoint @@ -668,55 +558,54 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal /** - * WebSocket 1.0. Section 2.1.5. - * Need internal close method as spec requires that the local endpoint - * receives a 1006 on timeout. + * WebSocket 1.0. Section 2.1.5. Need internal close method as spec requires that the local endpoint receives a 1006 + * on timeout. * * @param closeReasonMessage The close reason to pass to the remote endpoint * @param closeReasonLocal The close reason to pass to the local endpoint - * @param closeSocket Should the socket be closed immediately rather than waiting - * for the server to respond + * @param closeSocket Should the socket be closed immediately rather than waiting for the server to respond */ - public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal, - boolean closeSocket) { - // Double-checked locking. OK because state is volatile - if (state != State.OPEN) { + public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal, boolean closeSocket) { + + if (!state.compareAndSet(State.OPEN, State.OUTPUT_CLOSING)) { + // Close process has already been started. Don't start it again. return; } - synchronized (stateLock) { - if (state != State.OPEN) { - return; - } - - if (log.isDebugEnabled()) { - log.debug(sm.getString("wsSession.doClose", id)); - } + if (log.isTraceEnabled()) { + log.trace(sm.getString("wsSession.doClose", id)); + } - // This will trigger a flush of any batched messages. - try { - wsRemoteEndpoint.setBatchingAllowed(false); - } catch (IOException e) { - log.warn(sm.getString("wsSession.flushFailOnClose"), e); - fireEndpointOnError(e); - } + // Flush any batched messages not yet sent. + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.warn(sm.getString("wsSession.flushFailOnClose"), t); + fireEndpointOnError(t); + } + // Send the close message to the remote endpoint. + sendCloseMessage(closeReasonMessage); + fireEndpointOnClose(closeReasonLocal); + if (!state.compareAndSet(State.OUTPUT_CLOSING, State.OUTPUT_CLOSED) || closeSocket) { /* - * If the flush above fails the error handling could call this - * method recursively. Without this check, the close message and - * notifications could be sent multiple times. + * A close message was received in another thread or this is handling an error condition. Either way, no + * further close message is expected to be received. Mark the session as fully closed... */ - if (state != State.OUTPUT_CLOSED) { - state = State.OUTPUT_CLOSED; - - sendCloseMessage(closeReasonMessage); - if (closeSocket) { - wsRemoteEndpoint.close(); - } - fireEndpointOnClose(closeReasonLocal); - } + state.set(State.CLOSED); + // ... and close the network connection. + closeConnection(); + } else { + /* + * Set close timeout. If the client fails to send a close message response within the timeout, the session + * and the connection will be closed when the timeout expires. + */ + sessionCloseTimeoutExpiry = + Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout())); } + // Fail any uncompleted messages. IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); SendResult sr = new SendResult(ioe); for (FutureToSendHandler f2sh : futures.keySet()) { @@ -726,36 +615,93 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal /** - * Called when a close message is received. Should only ever happen once. - * Also called after a protocol error when the ProtocolHandler needs to - * force the closing of the connection. + * Called when a close message is received. Should only ever happen once. Also called after a protocol error when + * the ProtocolHandler needs to force the closing of the connection. * - * @param closeReason The reason contained within the received close - * message. + * @param closeReason The reason contained within the received close message. */ public void onClose(CloseReason closeReason) { + if (state.compareAndSet(State.OPEN, State.CLOSING)) { + // Standard close. - synchronized (stateLock) { - if (state != State.CLOSED) { - try { - wsRemoteEndpoint.setBatchingAllowed(false); - } catch (IOException e) { - log.warn(sm.getString("wsSession.flushFailOnClose"), e); - fireEndpointOnError(e); - } - if (state == State.OPEN) { - state = State.OUTPUT_CLOSED; - sendCloseMessage(closeReason); - fireEndpointOnClose(closeReason); - } - state = State.CLOSED; + // Flush any batched messages not yet sent. + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.warn(sm.getString("wsSession.flushFailOnClose"), t); + fireEndpointOnError(t); + } + + // Send the close message response to the remote endpoint. + sendCloseMessage(closeReason); + fireEndpointOnClose(closeReason); + + // Mark the session as fully closed. + state.set(State.CLOSED); - // Close the socket - wsRemoteEndpoint.close(); + // Close the network connection. + closeConnection(); + } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) { + /* + * The local endpoint sent a close message the the same time as the remote endpoint. The local close is + * still being processed. Update the state so the the local close process will also close the network + * connection once it has finished sending a close message. + */ + } else if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) { + /* + * The local endpoint sent the first close message. The remote endpoint has now responded with its own close + * message so mark the session as fully closed and close the network connection. + */ + closeConnection(); + } + // CLOSING and CLOSED are NO-OPs + } + + + private void closeConnection() { + /* + * Close the network connection. + */ + wsRemoteEndpoint.close(); + /* + * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for + * tracking the session close timeout. + */ + webSocketContainer.unregisterSession(getSessionMapKey(), this); + } + + + /* + * Returns the session close timeout in milliseconds + */ + protected long getSessionCloseTimeout() { + long result = 0; + Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY); + if (obj instanceof Long) { + result = ((Long) obj).intValue(); + } + if (result <= 0) { + result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT; + } + return result; + } + + + protected void checkCloseTimeout() { + // Skip the check if no session close timeout has been set. + if (sessionCloseTimeoutExpiry != null) { + // Check if the timeout has expired. + if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) { + // Check if the session has been closed in another thread while the timeout was being processed. + if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) { + closeConnection(); + } } } } + private void fireEndpointOnClose(CloseReason closeReason) { // Fire the onClose event @@ -789,7 +735,6 @@ private void fireEndpointOnClose(CloseReason closeReason) { } - private void fireEndpointOnError(Throwable throwable) { // Fire the onError event @@ -829,7 +774,7 @@ private void sendCloseMessage(CloseReason closeReason) { if (log.isDebugEnabled()) { log.debug(sm.getString("wsSession.sendCloseFail", id), e); } - wsRemoteEndpoint.close(); + closeConnection(); // Failure to send a close message is not unexpected in the case of // an abnormal closure (usually triggered by a failure to read/write // from/to the client. In this case do not trigger the endpoint's @@ -837,8 +782,6 @@ private void sendCloseMessage(CloseReason closeReason) { if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { localEndpoint.onError(this, e); } - } finally { - webSocketContainer.unregisterSession(getSessionMapKey(), this); } } @@ -855,7 +798,8 @@ private Object getSessionMapKey() { /** * Use protected so unit tests can access this method directly. - * @param msg The message + * + * @param msg The message * @param reason The reason */ protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) { @@ -885,9 +829,9 @@ protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String rea /** - * Make the session aware of a {@link FutureToSendHandler} that will need to - * be forcibly closed if the session closes before the - * {@link FutureToSendHandler} completes. + * Make the session aware of a {@link FutureToSendHandler} that will need to be forcibly closed if the session + * closes before the {@link FutureToSendHandler} completes. + * * @param f2sh The handler */ protected void registerFuture(FutureToSendHandler f2sh) { @@ -900,13 +844,13 @@ protected void registerFuture(FutureToSendHandler f2sh) { // Always register the future. futures.put(f2sh, f2sh); - if (state == State.OPEN) { + if (isOpen()) { // The session is open. The future has been registered with the open // session. Normal processing continues. return; } - // The session is closed. The future may or may not have been registered + // The session is closing / closed. The future may or may not have been registered // in time for it to be processed during session closure. if (f2sh.isDone()) { @@ -916,7 +860,7 @@ protected void registerFuture(FutureToSendHandler f2sh) { return; } - // The session is closed. The Future had not completed when last checked. + // The session is closing / closed. The Future had not completed when last checked. // There is a small timing window that means the Future may have been // completed since the last check. There is also the possibility that // the Future was not registered in time to be cleaned up during session @@ -928,7 +872,7 @@ protected void registerFuture(FutureToSendHandler f2sh) { // complete the Future but knowing if this is the case requires the sync // on stateLock (see above). // Note: If multiple attempts are made to complete the Future, the - // second and subsequent attempts are ignored. + // second and subsequent attempts are ignored. IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); SendResult sr = new SendResult(ioe); @@ -938,6 +882,7 @@ protected void registerFuture(FutureToSendHandler f2sh) { /** * Remove a {@link FutureToSendHandler} from the set of tracked instances. + * * @param f2sh The handler */ protected void unregisterFuture(FutureToSendHandler f2sh) { @@ -969,6 +914,11 @@ public String getQueryString() { @Override public Principal getUserPrincipal() { checkState(); + return getUserPrincipalInternal(); + } + + + public Principal getUserPrincipalInternal() { return userPrincipal; } @@ -1075,10 +1025,10 @@ private long getMaxIdleTimeoutWrite() { private void checkState() { - if (state == State.CLOSED) { + if (isClosed()) { /* - * As per RFC 6455, a WebSocket connection is considered to be - * closed once a peer has sent and received a WebSocket close frame. + * As per RFC 6455, a WebSocket connection is considered to be closed once a peer has sent and received a + * WebSocket close frame. */ throw new IllegalStateException(sm.getString("wsSession.closed", id)); } @@ -1086,12 +1036,15 @@ private void checkState() { private enum State { OPEN, + OUTPUT_CLOSING, OUTPUT_CLOSED, + CLOSING, CLOSED } private WsFrameBase wsFrame; + void setWsFrame(WsFrameBase wsFrame) { this.wsFrame = wsFrame; } diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java index 509fabe97910..b7cf468258c5 100644 --- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java +++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java @@ -646,7 +646,12 @@ Set getOpenSessions(Object key) { synchronized (endPointSessionMapLock) { Set sessions = endpointSessionMap.get(key); if (sessions != null) { - result.addAll(sessions); + // Some sessions may be in the process of closing + for (WsSession session : sessions) { + if (session.isOpen()) { + result.add(session); + } + } } } return result; @@ -1108,12 +1113,14 @@ private AsynchronousChannelGroup getAsynchronousChannelGroup() { @Override public void backgroundProcess() { // This method gets called once a second. - backgroundProcessCount ++; + backgroundProcessCount++; if (backgroundProcessCount >= processPeriod) { backgroundProcessCount = 0; + // Check all registered sessions. for (WsSession wsSession : sessions.keySet()) { wsSession.checkExpiration(); + wsSession.checkCloseTimeout(); } } diff --git a/java/org/apache/tomcat/websocket/server/WsServerContainer.java b/java/org/apache/tomcat/websocket/server/WsServerContainer.java index ba78a84f857c..134980ab318d 100644 --- a/java/org/apache/tomcat/websocket/server/WsServerContainer.java +++ b/java/org/apache/tomcat/websocket/server/WsServerContainer.java @@ -475,10 +475,8 @@ protected void registerSession(Object key, WsSession wsSession) { */ @Override protected void unregisterSession(Object key, WsSession wsSession) { - if (wsSession.getUserPrincipal() != null && - wsSession.getHttpSessionId() != null) { - unregisterAuthenticatedSession(wsSession, - wsSession.getHttpSessionId()); + if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) { + unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId()); } super.unregisterSession(key, wsSession); } diff --git a/test/org/apache/catalina/startup/TomcatBaseTest.java b/test/org/apache/catalina/startup/TomcatBaseTest.java index 2ab8abd43031..f72300a41852 100644 --- a/test/org/apache/catalina/startup/TomcatBaseTest.java +++ b/test/org/apache/catalina/startup/TomcatBaseTest.java @@ -138,6 +138,16 @@ public Tomcat getTomcatInstanceTestWebapp(boolean addJstl, boolean start) return tomcat; } + + public Context getProgrammaticRootContext() { + // No file system docBase required + Context ctx = tomcat.addContext("", null); + // Disable class path scanning - it slows the tests down by almost an order of magnitude + ((StandardJarScanner) ctx.getJarScanner()).setScanClassPath(false); + return ctx; + } + + /* * Sub-classes need to know port so they can connect */ diff --git a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java index b17dddf54f55..484e1a18dbb2 100644 --- a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java +++ b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java @@ -23,6 +23,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import jakarta.servlet.ServletContextEvent; +import jakarta.servlet.ServletContextListener; import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.CloseReason; import jakarta.websocket.ContainerProvider; @@ -39,16 +41,21 @@ import org.apache.catalina.servlets.DefaultServlet; import org.apache.catalina.startup.Tomcat; import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint; +import org.apache.tomcat.websocket.server.Constants; import org.apache.tomcat.websocket.server.TesterEndpointConfig; +import org.apache.tomcat.websocket.server.WsServerContainer; public class TestWsSessionSuspendResume extends WebSocketBaseTest { @Test - public void test() throws Exception { + public void testSuspendResume() throws Exception { + //public void test() throws Exception { Tomcat tomcat = getTomcatInstance(); - Context ctx = tomcat.addContext("", null); - ctx.addApplicationListener(Config.class.getName()); + //Context ctx = tomcat.addContext("", null); + Context ctx = getProgrammaticRootContext(); + //ctx.addApplicationListener(Config.class.getName()); + ctx.addApplicationListener(SuspendResumeConfig.class.getName()); Tomcat.addServlet(ctx, "default", new DefaultServlet()); ctx.addServletMappingDecoded("/", "default"); @@ -58,10 +65,12 @@ public void test() throws Exception { WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build(); - Session wsSession = wsContainer.connectToServer( - TesterProgrammaticEndpoint.class, - clientEndpointConfig, - new URI("ws://localhost:" + getPort() + Config.PATH)); +// Session wsSession = wsContainer.connectToServer( +// TesterProgrammaticEndpoint.class, +// clientEndpointConfig, +// new URI("ws://localhost:" + getPort() + Config.PATH)); + Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig, + new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH)); CountDownLatch latch = new CountDownLatch(2); wsSession.addMessageHandler(String.class, message -> { @@ -79,7 +88,8 @@ public void test() throws Exception { } - public static final class Config extends TesterEndpointConfig { + //public static final class Config extends TesterEndpointConfig { + public static final class SuspendResumeConfig extends TesterEndpointConfig { private static final String PATH = "/echo"; @Override @@ -97,8 +107,9 @@ protected ServerEndpointConfig getServerEndpointConfig() { public static final class SuspendResumeEndpoint extends Endpoint { @Override - public void onOpen(Session session, EndpointConfig epc) { - MessageProcessor processor = new MessageProcessor(session, 3); + public void onOpen(Session session, EndpointConfig epc) { + //MessageProcessor processor = new MessageProcessor(session, 3); + SuspendResumeMessageProcessor processor = new SuspendResumeMessageProcessor(session, 3); session.addMessageHandler(String.class, message -> processor.addMessage(message)); } @@ -118,12 +129,14 @@ public void onError(Session session, Throwable t) { } - private static final class MessageProcessor { + //private static final class MessageProcessor { + private static final class SuspendResumeMessageProcessor { private final Session session; private final int count; private final List messages = new ArrayList<>(); - MessageProcessor(Session session, int count) { + //MessageProcessor(Session session, int count) { + SuspendResumeMessageProcessor(Session session, int count) { this.session = session; this.count = count; } @@ -143,4 +156,99 @@ void addMessage(String message) { } } } + + + @Test + public void testSuspendThenClose() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + ctx.addApplicationListener(SuspendCloseConfig.class.getName()); + ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName()); + + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + tomcat.start(); + + WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); + + ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build(); + Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig, + new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH)); + + wsSession.getBasicRemote().sendText("start test"); + + // Wait for the client response to be received by the server + int count = 0; + while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) { + Thread.sleep(100); + count ++; + } + Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed()); + } + + + public static final class SuspendCloseConfig extends TesterEndpointConfig { + private static final String PATH = "/echo"; + + @Override + protected Class getEndpointClass() { + return SuspendCloseEndpoint.class; + } + + @Override + protected ServerEndpointConfig getServerEndpointConfig() { + return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build(); + } + } + + + public static final class SuspendCloseEndpoint extends Endpoint { + + // Yes, a static variable is a hack. + private static WsSession serverSession; + + @Override + public void onOpen(Session session, EndpointConfig epc) { + serverSession = (WsSession) session; + // Set a short session close timeout (milliseconds) + serverSession.getUserProperties().put( + org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000)); + // Any message will trigger the suspend then close + serverSession.addMessageHandler(String.class, message -> { + try { + serverSession.getBasicRemote().sendText("server session open"); + serverSession.getBasicRemote().sendText("suspending server session"); + serverSession.suspend(); + serverSession.getBasicRemote().sendText("closing server session"); + serverSession.close(); + } catch (IOException ioe) { + ioe.printStackTrace(); + // Attempt to make the failure more obvious + throw new RuntimeException(ioe); + } + }); + } + + @Override + public void onError(Session session, Throwable t) { + t.printStackTrace(); + } + + public static boolean isServerSessionFullyClosed() { + return serverSession.isClosed(); + } + } + + + public static class WebSocketFastServerTimeout implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + container.setProcessPeriod(0); + } + } } \ No newline at end of file diff --git a/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainer.java b/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainerClient.java similarity index 82% rename from test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainer.java rename to test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainerClient.java index 8f5dfb6881aa..2f101f3473fb 100644 --- a/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainer.java +++ b/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainerClient.java @@ -35,14 +35,14 @@ * significantly extends the length of a test run when using multiple test * threads. */ -public class TestWsWebSocketContainerSessionExpiryContainer extends WsWebSocketContainerBaseTest { +public class TestWsWebSocketContainerSessionExpiryContainerClient extends WsWebSocketContainerBaseTest { @Test public void testSessionExpiryContainer() throws Exception { Tomcat tomcat = getTomcatInstance(); // No file system docBase required - Context ctx = tomcat.addContext("", null); + Context ctx = getProgrammaticRootContext(); ctx.addApplicationListener(TesterEchoServer.Config.class.getName()); Tomcat.addServlet(ctx, "default", new DefaultServlet()); ctx.addServletMappingDecoded("/", "default"); @@ -50,20 +50,16 @@ public void testSessionExpiryContainer() throws Exception { tomcat.start(); // Need access to implementation methods for configuring unit tests - WsWebSocketContainer wsContainer = (WsWebSocketContainer) - ContainerProvider.getWebSocketContainer(); + WsWebSocketContainer wsContainer = (WsWebSocketContainer) ContainerProvider.getWebSocketContainer(); // 5 second timeout wsContainer.setDefaultMaxSessionIdleTimeout(5000); wsContainer.setProcessPeriod(1); EndpointA endpointA = new EndpointA(); - connectToEchoServer(wsContainer, endpointA, - TesterEchoServer.Config.PATH_BASIC); - connectToEchoServer(wsContainer, endpointA, - TesterEchoServer.Config.PATH_BASIC); - Session s3a = connectToEchoServer(wsContainer, endpointA, - TesterEchoServer.Config.PATH_BASIC); + connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC); + connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC); + Session s3a = connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC); // Check all three sessions are open Set setA = s3a.getOpenSessions(); @@ -71,9 +67,9 @@ public void testSessionExpiryContainer() throws Exception { int count = 0; boolean isOpen = true; - while (isOpen && count < 8) { - count ++; - Thread.sleep(1000); + while (isOpen && count < 100) { + count++; + Thread.sleep(100); isOpen = false; for (Session session : setA) { if (session.isOpen()) { @@ -86,8 +82,7 @@ public void testSessionExpiryContainer() throws Exception { if (isOpen) { for (Session session : setA) { if (session.isOpen()) { - System.err.println("Session with ID [" + session.getId() + - "] is open"); + System.err.println("Session with ID [" + session.getId() + "] is open"); } } Assert.fail("There were open sessions"); diff --git a/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainerServer.java b/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainerServer.java new file mode 100644 index 000000000000..16fbf4135e91 --- /dev/null +++ b/test/org/apache/tomcat/websocket/TestWsWebSocketContainerSessionExpiryContainerServer.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.tomcat.websocket; + +import java.util.Set; + +import jakarta.servlet.ServletContextEvent; +import jakarta.servlet.ServletContextListener; +import jakarta.websocket.ContainerProvider; +import jakarta.websocket.Session; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.servlets.DefaultServlet; +import org.apache.catalina.startup.Tomcat; +import org.apache.tomcat.websocket.TestWsWebSocketContainer.EndpointA; +import org.apache.tomcat.websocket.server.Constants; +import org.apache.tomcat.websocket.server.WsServerContainer; + +/* + * Moved to separate test class to improve test concurrency. These tests are + * some of the last tests to start and having them all in a single class + * significantly extends the length of a test run when using multiple test + * threads. + */ +public class TestWsWebSocketContainerSessionExpiryContainerServer extends WsWebSocketContainerBaseTest { + + @Test + public void testSessionExpiryContainer() throws Exception { + + Tomcat tomcat = getTomcatInstance(); + // No file system docBase required + Context ctx = getProgrammaticRootContext(); + ctx.addApplicationListener(TesterEchoServer.Config.class.getName()); + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + ctx.addApplicationListener(WebSocketServerTimeoutConfig.class.getName()); + + tomcat.start(); + + // Need access to implementation methods for configuring unit tests + WsWebSocketContainer wsContainer = (WsWebSocketContainer) ContainerProvider.getWebSocketContainer(); + + EndpointA endpointA = new EndpointA(); + connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC); + connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC); + Session s3a = connectToEchoServer(wsContainer, endpointA, TesterEchoServer.Config.PATH_BASIC); + + // Check all three sessions are open + Set setA = s3a.getOpenSessions(); + Assert.assertEquals(3, setA.size()); + + int count = 0; + boolean isOpen = true; + while (isOpen && count < 100) { + count++; + Thread.sleep(100); + isOpen = false; + for (Session session : setA) { + if (session.isOpen()) { + isOpen = true; + break; + } + } + } + + if (isOpen) { + for (Session session : setA) { + if (session.isOpen()) { + System.err.println("Session with ID [" + session.getId() + "] is open"); + } + } + Assert.fail("There were open sessions"); + } + } + + + public static class WebSocketServerTimeoutConfig implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + // Process timeouts on every run of the background thread + container.setProcessPeriod(0); + // Set session timeout to 5s + container.setDefaultMaxSessionIdleTimeout(5000); + } + } +} diff --git a/test/org/apache/tomcat/websocket/WebSocketBaseTest.java b/test/org/apache/tomcat/websocket/WebSocketBaseTest.java index 3d55c9d16180..5e714ab1ec37 100644 --- a/test/org/apache/tomcat/websocket/WebSocketBaseTest.java +++ b/test/org/apache/tomcat/websocket/WebSocketBaseTest.java @@ -28,13 +28,11 @@ public abstract class WebSocketBaseTest extends TomcatBaseTest { - protected Tomcat startServer( - final Class configClass) - throws LifecycleException { + protected Tomcat startServer(final Class configClass) throws LifecycleException { Tomcat tomcat = getTomcatInstance(); // No file system docBase required - Context ctx = tomcat.addContext("", null); + Context ctx = getProgrammaticRootContext(); ctx.addApplicationListener(configClass.getName()); Tomcat.addServlet(ctx, "default", new DefaultServlet()); ctx.addServletMappingDecoded("/", "default"); diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml index 3708876aa34f..84b8d672c4bd 100644 --- a/webapps/docs/changelog.xml +++ b/webapps/docs/changelog.xml @@ -478,6 +478,11 @@ endpoint via a forward proxy that requires authentication. Based on a patch provided by Joe Mokos. (markt) + + Ensure that WebSocket connection closure completes if the connection is + closed when the server side has used the proprietary suspend/resume + feature to suspend the connection. (markt) + diff --git a/webapps/docs/web-socket-howto.xml b/webapps/docs/web-socket-howto.xml index 34bd6c5eb2ef..10922ca6746a 100644 --- a/webapps/docs/web-socket-howto.xml +++ b/webapps/docs/web-socket-howto.xml @@ -64,6 +64,13 @@ the timeout to use in milliseconds. For an infinite timeout, use -1.

+

The session close timeout defaults to 30000 milliseconds (30 seconds). This + may be changed by setting the property + org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT in the user + properties collection attached to the WebSocket session. The value assigned + to this property should be a Long and represents the timeout to + use in milliseconds. Values less than or equal to zero will be ignored.

+

In addition to the Session.setMaxIdleTimeout(long) method which is part of the Jakarta WebSocket API, Tomcat provides greater control of the timing out the session due to lack of activity. Setting the property