diff --git a/crates/rmcp/src/transport/common/server_side_http.rs b/crates/rmcp/src/transport/common/server_side_http.rs index 693d8b34..51cd51f5 100644 --- a/crates/rmcp/src/transport/common/server_side_http.rs +++ b/crates/rmcp/src/transport/common/server_side_http.rs @@ -6,6 +6,7 @@ use http::Response; use http_body::Body; use http_body_util::{BodyExt, Empty, Full, combinators::BoxBody}; use sse_stream::{KeepAlive, Sse, SseBody}; +use tokio_util::sync::CancellationToken; use super::http_header::EVENT_STREAM_MIME_TYPE; use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; @@ -65,20 +66,26 @@ pub struct ServerSseMessage { pub(crate) fn sse_stream_response( stream: impl futures::Stream + Send + Sync + 'static, keep_alive: Option, + ct: CancellationToken, ) -> Response> { use futures::StreamExt; - let stream = SseBody::new(stream.map(|message| { - let data = serde_json::to_string(&message.message).expect("valid message"); - let mut sse = Sse::default().data(data); - sse.id = message.event_id; - Result::::Ok(sse) - })); + let stream = stream + .map(|message| { + let data = serde_json::to_string(&message.message).expect("valid message"); + let mut sse = Sse::default().data(data); + sse.id = message.event_id; + Result::::Ok(sse) + }) + .take_until(async move { ct.cancelled().await }); + let stream = SseBody::new(stream); + let stream = match keep_alive { Some(duration) => stream .with_keep_alive::(KeepAlive::new().interval(duration)) .boxed(), None => stream.boxed(), }; + Response::builder() .status(http::StatusCode::OK) .header(http::header::CONTENT_TYPE, EVENT_STREAM_MIME_TYPE) diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index ba373d48..08789566 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -6,6 +6,7 @@ use http::{Method, Request, Response, header::ALLOW}; use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use super::session::SessionManager; use crate::{ @@ -33,6 +34,11 @@ pub struct StreamableHttpServerConfig { pub sse_keep_alive: Option, /// If true, the server will create a session for each request and keep it alive. pub stateful_mode: bool, + /// Cancellation token for the Streamable HTTP server. + /// + /// When this token is cancelled, all active sessions are terminated and + /// the server stops accepting new requests. + pub cancellation_token: CancellationToken, } impl Default for StreamableHttpServerConfig { @@ -40,6 +46,7 @@ impl Default for StreamableHttpServerConfig { Self { sse_keep_alive: Some(Duration::from_secs(15)), stateful_mode: true, + cancellation_token: CancellationToken::new(), } } } @@ -209,7 +216,11 @@ where .resume(&session_id, last_event_id) .await .map_err(internal_error_response("resume session"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + Ok(sse_stream_response( + stream, + self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), + )) } else { // create standalone stream let stream = self @@ -217,7 +228,11 @@ where .create_standalone_stream(&session_id) .await .map_err(internal_error_response("create standalone stream"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + Ok(sse_stream_response( + stream, + self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), + )) } } @@ -307,7 +322,11 @@ where .create_stream(&session_id, message) .await .map_err(internal_error_response("get session"))?; - Ok(sse_stream_response(stream, self.config.sse_keep_alive)) + Ok(sse_stream_response( + stream, + self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), + )) } ClientJsonRpcMessage::Notification(_) | ClientJsonRpcMessage::Response(_) @@ -380,6 +399,7 @@ where } }), self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), ); response.headers_mut().insert( @@ -413,6 +433,7 @@ where } }), self.config.sse_keep_alive, + self.config.cancellation_token.child_token(), )) } ClientJsonRpcMessage::Notification(_notification) => { diff --git a/crates/rmcp/tests/test_with_js.rs b/crates/rmcp/tests/test_with_js.rs index 3f2761cd..f399bbe8 100644 --- a/crates/rmcp/tests/test_with_js.rs +++ b/crates/rmcp/tests/test_with_js.rs @@ -94,6 +94,7 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { .wait() .await?; + let ct = CancellationToken::new(); let service: StreamableHttpService = StreamableHttpService::new( || Ok(Calculator::new()), @@ -101,11 +102,12 @@ async fn test_with_js_streamable_http_client() -> anyhow::Result<()> { StreamableHttpServerConfig { stateful_mode: true, sse_keep_alive: None, + cancellation_token: ct.child_token(), }, ); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(STREAMABLE_HTTP_BIND_ADDRESS).await?; - let ct = CancellationToken::new(); + let handle = tokio::spawn({ let ct = ct.clone(); async move { diff --git a/examples/servers/src/counter_streamhttp.rs b/examples/servers/src/counter_streamhttp.rs index ff00cec6..db9b9df1 100644 --- a/examples/servers/src/counter_streamhttp.rs +++ b/examples/servers/src/counter_streamhttp.rs @@ -1,5 +1,5 @@ use rmcp::transport::streamable_http_server::{ - StreamableHttpService, session::local::LocalSessionManager, + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, }; use tracing_subscriber::{ layer::SubscriberExt, @@ -20,17 +20,24 @@ async fn main() -> anyhow::Result<()> { ) .with(tracing_subscriber::fmt::layer()) .init(); + let ct = tokio_util::sync::CancellationToken::new(); let service = StreamableHttpService::new( || Ok(Counter::new()), LocalSessionManager::default().into(), - Default::default(), + StreamableHttpServerConfig { + cancellation_token: ct.child_token(), + ..Default::default() + }, ); let router = axum::Router::new().nest_service("/mcp", service); let tcp_listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; let _ = axum::serve(tcp_listener, router) - .with_graceful_shutdown(async { tokio::signal::ctrl_c().await.unwrap() }) + .with_graceful_shutdown(async move { + tokio::signal::ctrl_c().await.unwrap(); + ct.cancel(); + }) .await; Ok(()) } diff --git a/examples/transport/src/named-pipe.rs b/examples/transport/src/named-pipe.rs index c091a95f..c472fad6 100644 --- a/examples/transport/src/named-pipe.rs +++ b/examples/transport/src/named-pipe.rs @@ -12,11 +12,11 @@ async fn main() -> anyhow::Result<()> { let mut server = ServerOptions::new() .first_pipe_instance(true) .create(name)?; - while let Ok(_) = server.connect().await { + while server.connect().await.is_ok() { let stream = server; server = ServerOptions::new().create(name)?; tokio::spawn(async move { - match serve_server(Calculator, stream).await { + match serve_server(Calculator::new(), stream).await { Ok(server) => { println!("Server initialized successfully"); if let Err(e) = server.waiting().await {