Skip to content

Commit e01edc1

Browse files
authored
Fix slow shutdown when a Streamable HTTP client is connected (#843)
* Add "Async" suffix to async methods
1 parent d914581 commit e01edc1

File tree

9 files changed

+103
-43
lines changed

9 files changed

+103
-43
lines changed

src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.AspNetCore.Http;
22
using Microsoft.AspNetCore.Http.Features;
33
using Microsoft.AspNetCore.WebUtilities;
4+
using Microsoft.Extensions.Hosting;
45
using Microsoft.Extensions.Logging;
56
using Microsoft.Extensions.Options;
67
using Microsoft.Net.Http.Headers;
@@ -17,8 +18,9 @@ internal sealed class StreamableHttpHandler(
1718
IOptionsFactory<McpServerOptions> mcpServerOptionsFactory,
1819
IOptions<HttpServerTransportOptions> httpServerTransportOptions,
1920
StatefulSessionManager sessionManager,
20-
ILoggerFactory loggerFactory,
21-
IServiceProvider applicationServices)
21+
IHostApplicationLifetime hostApplicationLifetime,
22+
IServiceProvider applicationServices,
23+
ILoggerFactory loggerFactory)
2224
{
2325
private const string McpSessionIdHeaderName = "Mcp-Session-Id";
2426

@@ -60,7 +62,7 @@ await WriteJsonRpcErrorAsync(context,
6062
}
6163

6264
InitializeSseResponse(context);
63-
var wroteResponse = await session.Transport.HandlePostRequest(message, context.Response.Body, context.RequestAborted);
65+
var wroteResponse = await session.Transport.HandlePostRequestAsync(message, context.Response.Body, context.RequestAborted);
6466
if (!wroteResponse)
6567
{
6668
// We wound up writing nothing, so there should be no Content-Type response header.
@@ -94,14 +96,28 @@ await WriteJsonRpcErrorAsync(context,
9496
return;
9597
}
9698

97-
await using var _ = await session.AcquireReferenceAsync(context.RequestAborted);
98-
InitializeSseResponse(context);
99+
// Link the GET request to both RequestAborted and ApplicationStopping.
100+
// The GET request should complete immediately during graceful shutdown without waiting for
101+
// in-flight POST requests to complete. This prevents slow shutdown when clients are still connected.
102+
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
103+
var cancellationToken = sseCts.Token;
99104

100-
// We should flush headers to indicate a 200 success quickly, because the initialization response
101-
// will be sent in response to a different POST request. It might be a while before we send a message
102-
// over this response body.
103-
await context.Response.Body.FlushAsync(context.RequestAborted);
104-
await session.Transport.HandleGetRequest(context.Response.Body, context.RequestAborted);
105+
try
106+
{
107+
await using var _ = await session.AcquireReferenceAsync(cancellationToken);
108+
InitializeSseResponse(context);
109+
110+
// We should flush headers to indicate a 200 success quickly, because the initialization response
111+
// will be sent in response to a different POST request. It might be a while before we send a message
112+
// over this response body.
113+
await context.Response.Body.FlushAsync(cancellationToken);
114+
await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken);
115+
}
116+
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
117+
{
118+
// RequestAborted always triggers when the client disconnects before a complete response body is written,
119+
// but this is how SSE connections are typically closed.
120+
}
105121
}
106122

107123
public async Task HandleDeleteRequestAsync(HttpContext context)

src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,27 @@ private async Task ReceiveUnsolicitedMessagesAsync()
175175
request.Headers.Accept.Add(s_textEventStreamMediaType);
176176
CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, SessionId, _negotiatedProtocolVersion);
177177

178-
using var response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false);
179-
180-
if (!response.IsSuccessStatusCode)
178+
// Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages.
179+
HttpResponseMessage response;
180+
try
181+
{
182+
response = await _httpClient.SendAsync(request, message: null, _connectionCts.Token).ConfigureAwait(false);
183+
}
184+
catch (HttpRequestException)
181185
{
182-
// Server support for the GET request is optional. If it fails, we don't care. It just means we won't receive unsolicited messages.
183186
return;
184187
}
185188

186-
using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false);
187-
await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false);
189+
using (response)
190+
{
191+
if (!response.IsSuccessStatusCode)
192+
{
193+
return;
194+
}
195+
196+
using var responseStream = await response.Content.ReadAsStreamAsync(_connectionCts.Token).ConfigureAwait(false);
197+
await ProcessSseResponseAsync(responseStream, relatedRpcRequest: null, _connectionCts.Token).ConfigureAwait(false);
198+
}
188199
}
189200

190201
private async Task<JsonRpcMessageWithId?> ProcessSseResponseAsync(Stream responseStream, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)

src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ private protected JsonRpcMessage()
4040
/// <remarks>
4141
/// This property should only be set when implementing a custom <see cref="ITransport"/>
4242
/// that needs to pass additional per-message context or to pass a <see cref="JsonRpcMessageContext.User"/>
43-
/// to <see cref="StreamableHttpServerTransport.HandlePostRequest(JsonRpcMessage, Stream, CancellationToken)"/>
43+
/// to <see cref="StreamableHttpServerTransport.HandlePostRequestAsync(JsonRpcMessage, Stream, CancellationToken)"/>
4444
/// or <see cref="SseResponseStreamTransport.OnMessageReceivedAsync(JsonRpcMessage, CancellationToken)"/> .
4545
/// </remarks>
4646
[JsonIgnore]

src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ public sealed class StreamableHttpServerTransport : ITransport
4343
/// <summary>
4444
/// Configures whether the transport should be in stateless mode that does not require all requests for a given session
4545
/// to arrive to the same ASP.NET Core application process. Unsolicited server-to-client messages are not supported in this mode,
46-
/// so calling <see cref="HandleGetRequest(Stream, CancellationToken)"/> results in an <see cref="InvalidOperationException"/>.
46+
/// so calling <see cref="HandleGetRequestAsync(Stream, CancellationToken)"/> results in an <see cref="InvalidOperationException"/>.
4747
/// Server-to-client requests are also unsupported, because the responses may arrive at another ASP.NET Core application process.
4848
/// Client sampling and roots capabilities are also disabled in stateless mode, because the server cannot make requests.
4949
/// </summary>
5050
public bool Stateless { get; init; }
5151

5252
/// <summary>
53-
/// Gets a value indicating whether the execution context should flow from the calls to <see cref="HandlePostRequest(JsonRpcMessage, Stream, CancellationToken)"/>
53+
/// Gets a value indicating whether the execution context should flow from the calls to <see cref="HandlePostRequestAsync(JsonRpcMessage, Stream, CancellationToken)"/>
5454
/// to the corresponding <see cref="JsonRpcMessageContext.ExecutionContext"/> property contained in the <see cref="JsonRpcMessage"/> instances returned by the <see cref="MessageReader"/>.
5555
/// </summary>
5656
/// <remarks>
@@ -76,7 +76,7 @@ public sealed class StreamableHttpServerTransport : ITransport
7676
/// <param name="sseResponseStream">The response stream to write MCP JSON-RPC messages as SSE events to.</param>
7777
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
7878
/// <returns>A task representing the send loop that writes JSON-RPC messages to the SSE response stream.</returns>
79-
public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken = default)
79+
public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationToken cancellationToken = default)
8080
{
8181
Throw.IfNull(sseResponseStream);
8282

@@ -111,7 +111,7 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c
111111
/// If 's an authenticated <see cref="ClaimsPrincipal"/> sent the message, that can be included in the <see cref="JsonRpcMessage.Context"/>.
112112
/// No other part of the context should be set.
113113
/// </para>
114-
public async Task<bool> HandlePostRequest(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default)
114+
public async Task<bool> HandlePostRequestAsync(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default)
115115
{
116116
Throw.IfNull(message);
117117
Throw.IfNull(responseStream);

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,10 @@ public async Task StreamableHttpClient_SendsMcpProtocolVersionHeader_AfterInitia
181181

182182
await mcpClient.DisposeAsync();
183183

184-
// The header should be included in the GET request, the initialized notification, the tools/list call, and the delete request.
185-
// The DELETE request won't be sent for Stateless mode due to the lack of an Mcp-Session-Id.
186-
Assert.Equal(Stateless ? 3 : 4, protocolVersionHeaderValues.Count);
184+
// The GET request might not have started in time, and the DELETE request won't be sent in
185+
// Stateless mode due to the lack of an Mcp-Session-Id, but the header should be included in the
186+
// initialized notification and the tools/list call at a minimum.
187+
Assert.True(protocolVersionHeaderValues.Count > 1);
187188
Assert.All(protocolVersionHeaderValues, v => Assert.Equal("2025-03-26", v));
188189
}
189190
}

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using ModelContextProtocol.Server;
99
using ModelContextProtocol.Tests.Utils;
1010
using System.ComponentModel;
11+
using System.Diagnostics;
1112
using System.Net;
1213
using System.Security.Claims;
1314

@@ -114,10 +115,9 @@ public async Task Messages_FromNewUser_AreRejected()
114115
}
115116

116117
[Fact]
117-
public async Task ClaimsPrincipal_CanBeInjectedIntoToolMethod()
118+
public async Task ClaimsPrincipal_CanBeInjected_IntoToolMethod()
118119
{
119120
Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools<ClaimsPrincipalTools>();
120-
Builder.Services.AddHttpContextAccessor();
121121

122122
await using var app = Builder.Build();
123123

@@ -211,6 +211,35 @@ public async Task Sampling_DoesNotCloseStream_Prematurely()
211211
m.Message.Contains("request '2' for method 'sampling/createMessage'"));
212212
}
213213

214+
[Fact]
215+
public async Task Server_ShutsDownQuickly_WhenClientIsConnected()
216+
{
217+
Builder.Services.AddMcpServer().WithHttpTransport().WithTools<ClaimsPrincipalTools>();
218+
219+
await using var app = Builder.Build();
220+
app.MapMcp();
221+
222+
await app.StartAsync(TestContext.Current.CancellationToken);
223+
224+
// Connect a client which will open a long-running GET request (SSE or Streamable HTTP)
225+
await using var mcpClient = await ConnectAsync();
226+
227+
// Verify the client is connected
228+
var tools = await mcpClient.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken);
229+
Assert.NotEmpty(tools);
230+
231+
// Now measure how long it takes to stop the server
232+
var stopwatch = Stopwatch.StartNew();
233+
await app.StopAsync(TestContext.Current.CancellationToken);
234+
stopwatch.Stop();
235+
236+
// The server should shut down quickly (within a few seconds). We use 5 seconds as a generous threshold.
237+
// This is much less than the default HostOptions.ShutdownTimeout of 30 seconds.
238+
Assert.True(stopwatch.Elapsed < TimeSpan.FromSeconds(5),
239+
$"Server took {stopwatch.Elapsed.TotalSeconds:F2} seconds to shut down with a connected client. " +
240+
"This suggests the GET request is not respecting ApplicationStopping token.");
241+
}
242+
214243
private ClaimsPrincipal CreateUser(string name)
215244
=> new ClaimsPrincipal(new ClaimsIdentity(
216245
[new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)],

tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryConnection.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,29 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils;
66

77
public sealed class KestrelInMemoryConnection : ConnectionContext
88
{
9-
private readonly Pipe _clientToServerPipe = new();
10-
private readonly Pipe _serverToClientPipe = new();
119
private readonly CancellationTokenSource _connectionClosedCts = new();
1210
private readonly FeatureCollection _features = new();
1311

1412
public KestrelInMemoryConnection()
1513
{
14+
Pipe clientToServerPipe = new();
15+
Pipe serverToClientPipe = new();
16+
1617
ConnectionClosed = _connectionClosedCts.Token;
1718
Transport = new DuplexPipe
1819
{
19-
Input = _clientToServerPipe.Reader,
20-
Output = _serverToClientPipe.Writer,
20+
Input = clientToServerPipe.Reader,
21+
Output = serverToClientPipe.Writer,
2122
};
22-
Application = new DuplexPipe
23+
ClientPipe = new DuplexPipe
2324
{
24-
Input = _serverToClientPipe.Reader,
25-
Output = _clientToServerPipe.Writer,
25+
Input = serverToClientPipe.Reader,
26+
Output = clientToServerPipe.Writer,
2627
};
27-
ClientStream = new DuplexStream(Application, _connectionClosedCts);
28+
ClientStream = new DuplexStream(ClientPipe, _connectionClosedCts);
2829
}
2930

30-
public IDuplexPipe Application { get; }
31+
public IDuplexPipe ClientPipe { get; }
3132
public Stream ClientStream { get; }
3233

3334
public override IDuplexPipe Transport { get; set; }
@@ -41,8 +42,8 @@ public override async ValueTask DisposeAsync()
4142
{
4243
// This is called by Kestrel. The client should dispose the DuplexStream which
4344
// completes the other half of these pipes.
44-
await _serverToClientPipe.Writer.CompleteAsync();
45-
await _serverToClientPipe.Reader.CompleteAsync();
45+
await Transport.Input.CompleteAsync();
46+
await Transport.Output.CompleteAsync();
4647

4748
// Don't bother disposing the _connectionClosedCts, since this is just for testing,
4849
// and it's annoying to synchronize with DuplexStream.

tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTransport.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ public sealed class KestrelInMemoryTransport : IConnectionListenerFactory
1313
public KestrelInMemoryConnection CreateConnection(EndPoint endpoint)
1414
{
1515
var connection = new KestrelInMemoryConnection();
16-
GetAcceptQueue(endpoint).Writer.TryWrite(connection);
16+
if (!GetAcceptQueue(endpoint).Writer.TryWrite(connection))
17+
{
18+
throw new IOException("The KestrelInMemoryTransport has been shut down.");
19+
};
20+
1721
return connection;
1822
}
1923

@@ -37,12 +41,9 @@ private sealed class KestrelInMemoryListener(EndPoint endpoint, Channel<Connecti
3741

3842
public async ValueTask<ConnectionContext?> AcceptAsync(CancellationToken cancellationToken = default)
3943
{
40-
if (await acceptQueue.Reader.WaitToReadAsync(cancellationToken))
44+
await foreach (var item in acceptQueue.Reader.ReadAllAsync(cancellationToken))
4145
{
42-
while (acceptQueue.Reader.TryRead(out var item))
43-
{
44-
return item;
45-
}
46+
return item;
4647
}
4748

4849
return null;

tests/ModelContextProtocol.TestSseServer/Program.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ private static void HandleStatelessMcp(IApplicationBuilder app)
370370
var serviceCollection = new ServiceCollection();
371371
serviceCollection.AddLogging();
372372
serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService<ILoggerFactory>());
373+
serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService<IHostApplicationLifetime>());
373374
serviceCollection.AddSingleton(app.ApplicationServices.GetRequiredService<DiagnosticListener>());
374375
serviceCollection.AddRoutingCore();
375376

0 commit comments

Comments
 (0)