From 1840f6364b3bf9b6b89f0c849af8ea663fc8f4cb Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 30 Jul 2025 10:38:07 +0200 Subject: [PATCH] feat: implement MCP-compliant keep-alive functionality for server transports - Add KeepAliveScheduler utility class for configurable periodic session pings - Integrate keep-alive support in WebFlux, WebMVC, and HttpServlet SSE transport providers - Add keepAliveInterval configuration option to all transport provider builders - Deprecate existing constructors in favor of builder pattern with enhanced configuration - Update graceful shutdown to properly clean up keep-alive schedulers - Add unit tests for KeepAliveScheduler functionality Implements MCP specification recommendations for connection health detection: - Configurable ping frequency to suit different network environments - Optional keep-alive (disabled by default) to avoid excessive network overhead - Proper resource cleanup to prevent connection leaks https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/ping#implementation-considerations Resolves: #414, #158 Replaces #353 Signed-off-by: Christian Tzolov --- .../WebFluxSseServerTransportProvider.java | 83 ++++- ...FluxStreamableServerTransportProvider.java | 42 ++- .../WebMvcSseServerTransportProvider.java | 158 ++++++++- ...bMvcStreamableServerTransportProvider.java | 40 ++- .../client/McpAsyncClient.java | 7 +- ...HttpServletSseServerTransportProvider.java | 70 +++- ...vletStreamableServerTransportProvider.java | 46 ++- .../spec/McpStreamableServerSession.java | 23 +- .../util/KeepAliveScheduler.java | 216 +++++++++++++ ...HttpServletStreamableIntegrationTests.java | 1 + .../util/KeepAliveSchedulerTests.java | 303 ++++++++++++++++++ 11 files changed, 954 insertions(+), 35 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index fde067f03..b1b5246c8 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,6 +1,7 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; +import java.time.Duration; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; @@ -11,6 +12,8 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; @@ -109,6 +112,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + /** * Constructs a new WebFlux SSE server transport provider instance with the default * SSE endpoint. @@ -118,7 +127,10 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -131,7 +143,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); } @@ -145,9 +160,32 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param sseEndpoint The SSE endpoint path. Must not be null. + * @param keepAliveInterval The interval for sending keep-alive pings to clients. + * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -161,6 +199,17 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } } @Override @@ -209,15 +258,6 @@ public Mono notifyClients(String method, Object params) { /** * Initiates a graceful shutdown of all the sessions. This method ensures all active * sessions are properly closed and cleaned up. - * - *

- * The shutdown process: - *

    - *
  • Marks the transport as closing to prevent new connections
  • - *
  • Closes each active session
  • - *
  • Removes closed sessions from the sessions map
  • - *
  • Times out after 5 seconds if shutdown takes too long
  • - *
* @return A Mono that completes when all sessions have been closed */ @Override @@ -225,7 +265,14 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()) .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) .flatMap(McpServerSession::closeGracefully) - .then(); + .then() + .doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -396,6 +443,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private Duration keepAliveInterval; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -446,6 +495,17 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the interval for sending keep-alive pings to clients. + * @param keepAliveInterval The keep-alive interval duration. If null, keep-alive + * is disabled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -456,7 +516,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index e277e4749..79224a57d 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -12,6 +12,8 @@ import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -28,6 +30,7 @@ import reactor.core.publisher.Mono; import java.io.IOException; +import java.time.Duration; import java.util.List; import java.util.concurrent.ConcurrentHashMap; @@ -58,8 +61,11 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private volatile boolean isClosing = false; + private KeepAliveScheduler keepAliveScheduler; + private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, - McpTransportContextExtractor contextExtractor, boolean disallowDelete) { + McpTransportContextExtractor contextExtractor, boolean disallowDelete, + Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); @@ -73,6 +79,20 @@ private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Stri .POST(this.mcpEndpoint, this::handlePost) .DELETE(this.mcpEndpoint, this::handleDelete) .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + else { + logger.warn("Keep-alive interval is not set or invalid. No keep-alive will be scheduled."); + } + } @Override @@ -105,6 +125,11 @@ public Mono closeGracefully() { .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) .flatMap(McpStreamableServerSession::closeGracefully) .then(); + }).then().doOnSuccess(v -> { + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } }); } @@ -368,6 +393,8 @@ public static class Builder { private boolean disallowDelete; + private Duration keepAliveInterval; + private Builder() { // used by a static method } @@ -424,6 +451,17 @@ public Builder disallowDelete(boolean disallowDelete) { return this; } + /** + * Sets the keep-alive interval for the server transport. + * @param keepAliveInterval The interval for sending keep-alive messages. If null, + * no keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with * the configured settings. @@ -435,7 +473,7 @@ public WebFluxStreamableServerTransportProvider build() { Assert.notNull(mcpEndpoint, "Message endpoint must be set"); return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor, - disallowDelete); + disallowDelete, keepAliveInterval); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 5aa89d529..b90f9fb3d 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -18,6 +18,8 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -107,6 +109,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + private KeepAliveScheduler keepAliveScheduler; + /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE * endpoint. @@ -116,7 +120,10 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -130,7 +137,10 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, "", messageEndpoint, sseEndpoint); } @@ -146,9 +156,33 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * * @param keepAliveInterval The interval for sending keep-alive messages to + * @throws IllegalArgumentException if any parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -162,6 +196,17 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } } @Override @@ -209,10 +254,13 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()).doFirst(() -> { this.isClosing = true; logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }) - .flatMap(McpServerSession::closeGracefully) - .then() - .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); + }).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -435,4 +483,106 @@ public void close() { } + /** + * Creates a new Builder instance for configuring and creating instances of + * WebMvcSseServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of WebMvcSseServerTransportProvider. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebMvcSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String baseUrl = ""; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private Duration keepAliveInterval; + + /** + * Sets the JSON object mapper to use for message serialization/deserialization. + * @param objectMapper The object mapper to use + * @return This builder instance for method chaining + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint path where clients will send their messages. + * @param messageEndpoint The message endpoint path + * @return This builder instance for method chaining + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the endpoint path where clients will establish SSE connections. + *

+ * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be + * used. + * @param sseEndpoint The SSE endpoint path + * @return This builder instance for method chaining + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

+ * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of WebMvcSseServerTransportProvider with the configured + * settings. + * @return A new WebMvcSseServerTransportProvider instance + * @throws IllegalStateException if objectMapper or messageEndpoint is not set + */ + public WebMvcSseServerTransportProvider build() { + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); + } + return new WebMvcSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval); + } + + } + } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index d14a51d87..391aa3e8d 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -33,6 +33,8 @@ import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -101,6 +103,8 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer */ private volatile boolean isClosing = false; + private KeepAliveScheduler keepAliveScheduler; + /** * Constructs a new WebMvcStreamableServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -113,7 +117,8 @@ public class WebMvcStreamableServerTransportProvider implements McpStreamableSer * @throws IllegalArgumentException if any parameter is null */ private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, - boolean disallowDelete, McpTransportContextExtractor contextExtractor) { + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); @@ -127,6 +132,19 @@ private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, Strin .POST(this.mcpEndpoint, this::handlePost) .DELETE(this.mcpEndpoint, this::handleDelete) .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + else { + logger.warn("Keep-alive interval is not set or invalid. No keep-alive will be scheduled."); + } } @Override @@ -184,6 +202,10 @@ public Mono closeGracefully() { this.sessions.clear(); logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } }); } @@ -584,6 +606,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private Duration keepAliveInterval; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -635,6 +659,18 @@ public Builder contextExtractor(McpTransportContextExtractor cont return this; } + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be created to periodically check and send keep-alive messages to clients. + * @param keepAliveInterval The interval duration for keep-alive messages, or null + * to disable keep-alive + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with * the configured settings. @@ -646,7 +682,7 @@ public WebMvcStreamableServerTransportProvider build() { Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete, - this.contextExtractor); + this.contextExtractor, this.keepAliveInterval); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 9e861deba..405e7123f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.client; import java.time.Duration; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -174,7 +176,10 @@ public class McpAsyncClient { Map> requestHandlers = new HashMap<>(); // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, params -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, params -> { + logger.debug("Received ping: {}", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); + return Mono.just(Map.of()); + }); // Roots List Request Handler if (this.clientCapabilities.roots() != null) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..5c0b85f26 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -6,6 +6,7 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.time.Duration; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -19,6 +20,7 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -103,6 +105,12 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Session factory for creating new sessions */ private McpServerSession.Factory sessionFactory; + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -110,7 +118,10 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement * serialization/deserialization * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); @@ -124,13 +135,47 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval) { + this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing.get()) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } } /** @@ -324,7 +369,13 @@ public Mono closeGracefully() { isClosing.set(true); logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -475,6 +526,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private Duration keepAliveInterval; + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param objectMapper The object mapper to use @@ -522,6 +575,18 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the interval for keep-alive pings. + *

+ * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -535,7 +600,8 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 4d2dc62f4..211a9c052 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -7,6 +7,7 @@ import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; @@ -28,12 +29,14 @@ import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -110,6 +113,12 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private volatile boolean isClosing = false; + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + /** * Constructs a new HttpServletStreamableServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization @@ -121,7 +130,8 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet * @throws IllegalArgumentException if any parameter is null */ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, - boolean disallowDelete, McpTransportContextExtractor contextExtractor) { + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); @@ -130,6 +140,18 @@ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } @Override @@ -187,6 +209,12 @@ public Mono closeGracefully() { this.sessions.clear(); logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } }); } @@ -737,6 +765,8 @@ public static class Builder { private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private Duration keepAliveInterval; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -784,6 +814,18 @@ public Builder contextExtractor(McpTransportContextExtractor return this; } + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be activated to periodically ping active sessions. + * @param keepAliveInterval The interval for keep-alive pings. If null, no + * keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} * with the configured settings. @@ -795,7 +837,7 @@ public HttpServletStreamableServerTransportProvider build() { Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, - this.disallowDelete, this.contextExtractor); + this.disallowDelete, this.contextExtractor, this.keepAliveInterval); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index f600f28b3..c9b041fd6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -1,26 +1,27 @@ package io.modelcontextprotocol.spec; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.fasterxml.jackson.core.type.TypeReference; + import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; - /** * Representation of a Streamable HTTP server session that keeps track of mapping * server-initiated requests to the client and mapping arriving responses. It also allows diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java new file mode 100644 index 000000000..30e8a2c2a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -0,0 +1,216 @@ +/** + * Copyright 2025 - 2025 the original author or authors. + */ +package io.modelcontextprotocol.util; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * A utility class for scheduling regular keep-alive calls to maintain connections. It + * sends periodic keep-alive, ping, messages to connected mcp clients to prevent idle + * timeouts. + * + * The pings are sent to all active mcp sessions at regular intervals. + * + * @author Christian Tzolov + */ +public class KeepAliveScheduler { + + private static final Logger logger = LoggerFactory.getLogger(KeepAliveScheduler.class); + + private static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + /** Initial delay before the first keepAlive call */ + private final Duration initialDelay; + + /** Interval between subsequent keepAlive calls */ + private final Duration interval; + + /** The scheduler used for executing keepAlive calls */ + private final Scheduler scheduler; + + /** The current state of the scheduler */ + private final AtomicBoolean isRunning = new AtomicBoolean(false); + + /** The current subscription for the keepAlive calls */ + private Disposable currentSubscription; + + // TODO Currently we do not support the streams (streamable http session created by + // http post/get) + + /** Supplier for reactive McpSession instances */ + private final Supplier> mcpSessions; + + /** + * Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a + * supplier for McpSession instances. + * @param scheduler The scheduler to use for executing keepAlive calls + * @param initialDelay Initial delay before the first keepAlive call + * @param interval Interval between subsequent keepAlive calls + * @param mcpSessions Supplier for McpSession instances + */ + KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval, + Supplier> mcpSessions) { + this.scheduler = scheduler; + this.initialDelay = initialDelay; + this.interval = interval; + this.mcpSessions = mcpSessions; + } + + /** + * Creates a new Builder instance for constructing KeepAliveScheduler. + * @return A new Builder instance + */ + public static Builder builder(Supplier> mcpSessions) { + return new Builder(mcpSessions); + } + + /** + * Starts regular keepAlive calls with sessions supplier. + * @return Disposable to control the scheduled execution + */ + public Disposable start() { + if (this.isRunning.compareAndSet(false, true)) { + + this.currentSubscription = Flux.interval(this.initialDelay, this.interval, this.scheduler) + .doOnNext(tick -> { + this.mcpSessions.get() + .flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF) + .doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session, + e.getMessage())) + .onErrorComplete()) + .subscribe(); + }) + .doOnCancel(() -> this.isRunning.set(false)) + .doOnComplete(() -> this.isRunning.set(false)) + .onErrorComplete(error -> { + logger.error("KeepAlive scheduler error", error); + this.isRunning.set(false); + return true; + }) + .subscribe(); + + return this.currentSubscription; + } + else { + throw new IllegalStateException("KeepAlive scheduler is already running. Stop it first."); + } + } + + /** + * Stops the currently running keepAlive scheduler. + */ + public void stop() { + if (this.currentSubscription != null && !this.currentSubscription.isDisposed()) { + this.currentSubscription.dispose(); + } + this.isRunning.set(false); + } + + /** + * Checks if the scheduler is currently running. + * @return true if running, false otherwise + */ + public boolean isRunning() { + return this.isRunning.get(); + } + + /** + * Shuts down the scheduler and releases resources. + */ + public void shutdown() { + stop(); + if (this.scheduler instanceof Disposable) { + ((Disposable) this.scheduler).dispose(); + } + } + + /** + * Builder class for creating KeepAliveScheduler instances with fluent API. + */ + public static class Builder { + + private Scheduler scheduler = Schedulers.boundedElastic(); + + private Duration initialDelay = Duration.ofSeconds(0); + + private Duration interval = Duration.ofSeconds(30); + + private Supplier> mcpSessions; + + /** + * Creates a new Builder instance with a supplier for McpSession instances. + * @param mcpSessions The supplier for McpSession instances + */ + Builder(Supplier> mcpSessions) { + Assert.notNull(mcpSessions, "McpSessions supplier must not be null"); + this.mcpSessions = mcpSessions; + } + + /** + * Sets the scheduler to use for executing keepAlive calls. + * @param scheduler The scheduler to use: + *
    + *
  • Schedulers.single() - single-threaded scheduler
  • + *
  • Schedulers.boundedElastic() - bounded elastic scheduler for I/O operations + * (Default)
  • + *
  • Schedulers.parallel() - parallel scheduler for CPU-intensive + * operations
  • + *
  • Schedulers.immediate() - immediate scheduler for synchronous execution
  • + *
+ * @return This builder instance for method chaining + */ + public Builder scheduler(Scheduler scheduler) { + Assert.notNull(scheduler, "Scheduler must not be null"); + this.scheduler = scheduler; + return this; + } + + /** + * Sets the initial delay before the first keepAlive call. + * @param initialDelay The initial delay duration + * @return This builder instance for method chaining + */ + public Builder initialDelay(Duration initialDelay) { + Assert.notNull(initialDelay, "Initial delay must not be null"); + this.initialDelay = initialDelay; + return this; + } + + /** + * Sets the interval between subsequent keepAlive calls. + * @param interval The interval duration + * @return This builder instance for method chaining + */ + public Builder interval(Duration interval) { + Assert.notNull(interval, "Interval must not be null"); + this.interval = interval; + return this; + } + + /** + * Builds and returns a new KeepAliveScheduler instance. + * @return A new KeepAliveScheduler configured with the builder's settings + */ + public KeepAliveScheduler build() { + return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions); + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 3377f98a6..ecb0c33c3 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -38,6 +38,7 @@ public void before() { mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() .objectMapper(new ObjectMapper()) .mcpEndpoint(MESSAGE_ENDPOINT) + .keepAliveInterval(Duration.ofSeconds(1)) .build(); tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java new file mode 100644 index 000000000..4de9363c2 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -0,0 +1,303 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.scheduler.VirtualTimeScheduler; + +/** + * Unit tests for {@link KeepAliveScheduler}. + * + * @author Christian Tzolov + */ +class KeepAliveSchedulerTests { + + private MockMcpSession mockSession1; + + private MockMcpSession mockSession2; + + private Supplier> mockSessionsSupplier; + + private VirtualTimeScheduler virtualTimeScheduler; + + @BeforeEach + void setUp() { + virtualTimeScheduler = VirtualTimeScheduler.create(); + mockSession1 = new MockMcpSession(); + mockSession2 = new MockMcpSession(); + mockSessionsSupplier = () -> Flux.just(mockSession1); + } + + @AfterEach + void tearDown() { + if (virtualTimeScheduler != null) { + virtualTimeScheduler.dispose(); + } + } + + @Test + void testBuilderWithNullSessionsSupplier() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("McpSessions supplier must not be null"); + } + + @Test + void testBuilderWithNullScheduler() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).scheduler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Scheduler must not be null"); + } + + @Test + void testBuilderWithNullInitialDelay() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).initialDelay(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Initial delay must not be null"); + } + + @Test + void testBuilderWithNullInterval() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).interval(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Interval must not be null"); + } + + @Test + void testBuilderDefaults() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier).build(); + + assertThat(scheduler).isNotNull(); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testStartWithMultipleSessions() { + mockSessionsSupplier = () -> Flux.just(mockSession1, mockSession2); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + assertThat(scheduler.isRunning()).isFalse(); + + // Start the scheduler + Disposable disposable = scheduler.start(); + + assertThat(scheduler.isRunning()).isTrue(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + + // Advance time to trigger the first ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify both sessions received ping + assertThat(mockSession1.getPingCount()).isEqualTo(1); + assertThat(mockSession2.getPingCount()).isEqualTo(1); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Second ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Third ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Fourth ping + + // Verify second ping was sent + assertThat(mockSession1.getPingCount()).isEqualTo(4); + assertThat(mockSession2.getPingCount()).isEqualTo(4); + + // Clean up + scheduler.stop(); + + assertThat(scheduler.isRunning()).isFalse(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isTrue(); + } + + @Test + void testStartWithEmptySessionsList() { + mockSessionsSupplier = () -> Flux.empty(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger ping attempts + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify no sessions were called (since list was empty) + assertThat(mockSession1.getPingCount()).isEqualTo(0); + assertThat(mockSession2.getPingCount()).isEqualTo(0); + + // Clean up + scheduler.stop(); + } + + @Test + void testStartWhenAlreadyRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + + // Try to start again - should throw exception + assertThatThrownBy(scheduler::start).isInstanceOf(IllegalStateException.class) + .hasMessage("KeepAlive scheduler is already running. Stop it first."); + + // Clean up + scheduler.stop(); + } + + @Test + void testStopWhenNotRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Should not throw exception when stopping a non-running scheduler + assertDoesNotThrow(scheduler::stop); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testShutdown() { + // Setup with a separate virtual time scheduler (which is disposable) + VirtualTimeScheduler separateScheduler = VirtualTimeScheduler.create(); + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(separateScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + assertThat(scheduler.isRunning()).isTrue(); + + // Shutdown should stop the scheduler and dispose the scheduler + scheduler.shutdown(); + assertThat(scheduler.isRunning()).isFalse(); + assertThat(separateScheduler.isDisposed()).isTrue(); + } + + @Test + void testPingFailureHandling() { + // Setup session that fails ping + mockSession1.setShouldFailPing(true); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger the ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify ping was attempted (error should be handled gracefully) + assertThat(mockSession1.getPingCount()).isEqualTo(1); + + // Scheduler should still be running despite the error + assertThat(scheduler.isRunning()).isTrue(); + + // Clean up + scheduler.stop(); + } + + @Test + void testDisposableReturnedFromStart() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start and get disposable + Disposable disposable = scheduler.start(); + + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + assertThat(scheduler.isRunning()).isTrue(); + + // Dispose directly through the returned disposable + disposable.dispose(); + + assertThat(disposable.isDisposed()).isTrue(); + assertThat(scheduler.isRunning()).isFalse(); + } + + /** + * Simple mock implementation of McpSession for testing purposes. + */ + private static class MockMcpSession implements McpSession { + + private final AtomicInteger pingCount = new AtomicInteger(0); + + private boolean shouldFailPing = false; + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + if (McpSchema.METHOD_PING.equals(method)) { + pingCount.incrementAndGet(); + if (shouldFailPing) { + return Mono.error(new RuntimeException("Connection failed")); + } + return Mono.just((T) new Object()); + } + return Mono.empty(); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + // No-op for mock + } + + public int getPingCount() { + return pingCount.get(); + } + + public void setShouldFailPing(boolean shouldFailPing) { + this.shouldFailPing = shouldFailPing; + } + + @Override + public String toString() { + return "MockMcpSession"; + } + + } + +}