From 0052b374684b613b0c849899b325ebe334ac6501 Mon Sep 17 00:00:00 2001 From: Mark Thomas Date: Thu, 18 Jan 2024 11:32:43 +0000 Subject: [PATCH] Refactor WebSocket close for suspend/resume Ensure that WebSocket connection closure completes if the connection is closed when the server side has used the proprietary suspend/resume feature to suspend the connection. --- .../apache/tomcat/websocket/Constants.java | 6 ++ .../apache/tomcat/websocket/WsSession.java | 67 +++++++++++-- .../websocket/WsWebSocketContainer.java | 9 +- .../websocket/server/WsServerContainer.java | 2 +- .../websocket/TestWsSessionSuspendResume.java | 99 +++++++++++++++++++ webapps/docs/changelog.xml | 5 + webapps/docs/web-socket-howto.xml | 7 ++ 7 files changed, 187 insertions(+), 8 deletions(-) diff --git a/java/org/apache/tomcat/websocket/Constants.java b/java/org/apache/tomcat/websocket/Constants.java index 85a6d1149c8a..16f3f8184fcb 100644 --- a/java/org/apache/tomcat/websocket/Constants.java +++ b/java/org/apache/tomcat/websocket/Constants.java @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.TimeUnit; import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.Extension; @@ -117,6 +118,11 @@ public class Constants { // Milliseconds so this is 20 seconds public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + // Configuration for session close timeout + public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT"; + // Default is 30 seconds - setting is in milliseconds + public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30); + // Configuration for read idle timeout on WebSocket session public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS"; diff --git a/java/org/apache/tomcat/websocket/WsSession.java b/java/org/apache/tomcat/websocket/WsSession.java index 309b08611887..0c1b2d18dd92 100644 --- a/java/org/apache/tomcat/websocket/WsSession.java +++ b/java/org/apache/tomcat/websocket/WsSession.java @@ -27,6 +27,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -115,6 +116,7 @@ public class WsSession implements Session { private volatile long lastActiveRead = System.currentTimeMillis(); private volatile long lastActiveWrite = System.currentTimeMillis(); private Map futures = new ConcurrentHashMap<>(); + private volatile Long sessionCloseTimeoutExpiry; /** @@ -593,7 +595,14 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal */ state.set(State.CLOSED); // ... and close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); + } else { + /* + * Set close timeout. If the client fails to send a close message response within the timeout, the session + * and the connection will be closed when the timeout expires. + */ + sessionCloseTimeoutExpiry = + Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout())); } // Fail any uncompleted messages. @@ -632,7 +641,7 @@ public void onClose(CloseReason closeReason) { state.set(State.CLOSED); // Close the network connection. - wsRemoteEndpoint.close(); + closeConnection(); } else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) { /* * The local endpoint sent a close message the the same time as the remote endpoint. The local close is @@ -644,12 +653,55 @@ public void onClose(CloseReason closeReason) { * The local endpoint sent the first close message. The remote endpoint has now responded with its own close * message so mark the session as fully closed and close the network connection. */ - wsRemoteEndpoint.close(); + closeConnection(); } // CLOSING and CLOSED are NO-OPs } + private void closeConnection() { + /* + * Close the network connection. + */ + wsRemoteEndpoint.close(); + /* + * Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for + * tracking the session close timeout. + */ + webSocketContainer.unregisterSession(getSessionMapKey(), this); + } + + + /* + * Returns the session close timeout in milliseconds + */ + protected long getSessionCloseTimeout() { + long result = 0; + Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY); + if (obj instanceof Long) { + result = ((Long) obj).intValue(); + } + if (result <= 0) { + result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT; + } + return result; + } + + + protected void checkCloseTimeout() { + // Skip the check if no session close timeout has been set. + if (sessionCloseTimeoutExpiry != null) { + // Check if the timeout has expired. + if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) { + // Check if the session has been closed in another thread while the timeout was being processed. + if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) { + closeConnection(); + } + } + } + } + + private void fireEndpointOnClose(CloseReason closeReason) { // Fire the onClose event @@ -722,7 +774,7 @@ private void sendCloseMessage(CloseReason closeReason) { if (log.isDebugEnabled()) { log.debug(sm.getString("wsSession.sendCloseFail", id), e); } - wsRemoteEndpoint.close(); + closeConnection(); // Failure to send a close message is not unexpected in the case of // an abnormal closure (usually triggered by a failure to read/write // from/to the client. In this case do not trigger the endpoint's @@ -730,8 +782,6 @@ private void sendCloseMessage(CloseReason closeReason) { if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { localEndpoint.onError(this, e); } - } finally { - webSocketContainer.unregisterSession(getSessionMapKey(), this); } } @@ -864,6 +914,11 @@ public String getQueryString() { @Override public Principal getUserPrincipal() { checkState(); + return getUserPrincipalInternal(); + } + + + public Principal getUserPrincipalInternal() { return userPrincipal; } diff --git a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java index df92c67ff32a..a5aa9667af07 100644 --- a/java/org/apache/tomcat/websocket/WsWebSocketContainer.java +++ b/java/org/apache/tomcat/websocket/WsWebSocketContainer.java @@ -610,7 +610,12 @@ Set getOpenSessions(Object key) { synchronized (endPointSessionMapLock) { Set sessions = endpointSessionMap.get(key); if (sessions != null) { - result.addAll(sessions); + // Some sessions may be in the process of closing + for (WsSession session : sessions) { + if (session.isOpen()) { + result.add(session); + } + } } } return result; @@ -1060,8 +1065,10 @@ public void backgroundProcess() { if (backgroundProcessCount >= processPeriod) { backgroundProcessCount = 0; + // Check all registered sessions. for (WsSession wsSession : sessions.keySet()) { wsSession.checkExpiration(); + wsSession.checkCloseTimeout(); } } diff --git a/java/org/apache/tomcat/websocket/server/WsServerContainer.java b/java/org/apache/tomcat/websocket/server/WsServerContainer.java index 8fb4eb967ca0..b3b37ca45640 100644 --- a/java/org/apache/tomcat/websocket/server/WsServerContainer.java +++ b/java/org/apache/tomcat/websocket/server/WsServerContainer.java @@ -349,7 +349,7 @@ protected void registerSession(Object key, WsSession wsSession) { */ @Override protected void unregisterSession(Object key, WsSession wsSession) { - if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) { + if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) { unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId()); } super.unregisterSession(key, wsSession); diff --git a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java index cb54821662a8..f624f5c87c95 100644 --- a/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java +++ b/test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java @@ -23,6 +23,8 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import jakarta.servlet.ServletContextEvent; +import jakarta.servlet.ServletContextListener; import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.CloseReason; import jakarta.websocket.ContainerProvider; @@ -39,7 +41,9 @@ import org.apache.catalina.servlets.DefaultServlet; import org.apache.catalina.startup.Tomcat; import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint; +import org.apache.tomcat.websocket.server.Constants; import org.apache.tomcat.websocket.server.TesterEndpointConfig; +import org.apache.tomcat.websocket.server.WsServerContainer; public class TestWsSessionSuspendResume extends WebSocketBaseTest { @@ -141,4 +145,99 @@ void addMessage(String message) { } } } + + + @Test + public void testSuspendThenClose() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + ctx.addApplicationListener(SuspendCloseConfig.class.getName()); + ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName()); + + Tomcat.addServlet(ctx, "default", new DefaultServlet()); + ctx.addServletMappingDecoded("/", "default"); + + tomcat.start(); + + WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer(); + + ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build(); + Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig, + new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH)); + + wsSession.getBasicRemote().sendText("start test"); + + // Wait for the client response to be received by the server + int count = 0; + while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) { + Thread.sleep(100); + count ++; + } + Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed()); + } + + + public static final class SuspendCloseConfig extends TesterEndpointConfig { + private static final String PATH = "/echo"; + + @Override + protected Class getEndpointClass() { + return SuspendCloseEndpoint.class; + } + + @Override + protected ServerEndpointConfig getServerEndpointConfig() { + return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build(); + } + } + + + public static final class SuspendCloseEndpoint extends Endpoint { + + // Yes, a static variable is a hack. + private static WsSession serverSession; + + @Override + public void onOpen(Session session, EndpointConfig epc) { + serverSession = (WsSession) session; + // Set a short session close timeout (milliseconds) + serverSession.getUserProperties().put( + org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000)); + // Any message will trigger the suspend then close + serverSession.addMessageHandler(String.class, message -> { + try { + serverSession.getBasicRemote().sendText("server session open"); + serverSession.getBasicRemote().sendText("suspending server session"); + serverSession.suspend(); + serverSession.getBasicRemote().sendText("closing server session"); + serverSession.close(); + } catch (IOException ioe) { + ioe.printStackTrace(); + // Attempt to make the failure more obvious + throw new RuntimeException(ioe); + } + }); + } + + @Override + public void onError(Session session, Throwable t) { + t.printStackTrace(); + } + + public static boolean isServerSessionFullyClosed() { + return serverSession.isClosed(); + } + } + + + public static class WebSocketFastServerTimeout implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + container.setProcessPeriod(0); + } + } } \ No newline at end of file diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml index 4a852bfd0f5e..eeb1bbba8e30 100644 --- a/webapps/docs/changelog.xml +++ b/webapps/docs/changelog.xml @@ -205,6 +205,11 @@ Review usage of debug logging and downgrade trace or data dumping operations from debug level to trace. (remm) + + Ensure that WebSocket connection closure completes if the connection is + closed when the server side has used the proprietary suspend/resume + feature to suspend the connection. (markt) + diff --git a/webapps/docs/web-socket-howto.xml b/webapps/docs/web-socket-howto.xml index 20cf2caf383b..2aaa808df2e6 100644 --- a/webapps/docs/web-socket-howto.xml +++ b/webapps/docs/web-socket-howto.xml @@ -64,6 +64,13 @@ the timeout to use in milliseconds. For an infinite timeout, use -1.

+

The session close timeout defaults to 30000 milliseconds (30 seconds). This + may be changed by setting the property + org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT in the user + properties collection attached to the WebSocket session. The value assigned + to this property should be a Long and represents the timeout to + use in milliseconds. Values less than or equal to zero will be ignored.

+

In addition to the Session.setMaxIdleTimeout(long) method which is part of the Jakarta WebSocket API, Tomcat provides greater control of the timing out the session due to lack of activity. Setting the property