diff --git a/mcp/streamable.go b/mcp/streamable.go
index d371c87..bb341a5 100644
--- a/mcp/streamable.go
+++ b/mcp/streamable.go
@@ -9,11 +9,14 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"math"
+	"math/rand/v2"
 	"net/http"
 	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
+	"time"

 	"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
 	"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -594,12 +597,39 @@ type StreamableClientTransport struct {
 	opts StreamableClientTransportOptions
 }

+// StreamableReconnectOptions defines parameters for client reconnect attempts.
+type StreamableReconnectOptions struct {
+	// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
+	// A value of 0 or less means never retry.
+	MaxRetries int
+
+	// growFactor is the multiplicative factor by which the delay increases after each attempt.
+	// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
+	// It must be 1.0 or greater if MaxRetries is greater than 0.
+	growFactor float64
+
+	// initialDelay is the base delay for the first reconnect attempt.
+	initialDelay time.Duration
+
+	// maxDelay caps the backoff delay, preventing it from growing indefinitely.
+	maxDelay time.Duration
+}
+
+// DefaultReconnectOptions provides sensible defaults for reconnect logic.
+var DefaultReconnectOptions = &StreamableReconnectOptions{
+	MaxRetries:   5,
+	growFactor:   1.5,
+	initialDelay: 1 * time.Second,
+	maxDelay:     30 * time.Second,
+}
+
 // StreamableClientTransportOptions provides options for the
 // [NewStreamableClientTransport] constructor.
 type StreamableClientTransportOptions struct {
 	// HTTPClient is the client to use for making HTTP requests. If nil,
 	// http.DefaultClient is used.
-	HTTPClient *http.Client
+	HTTPClient       *http.Client
+	ReconnectOptions *StreamableReconnectOptions
 }

 // NewStreamableClientTransport returns a new client transport that connects to
@@ -625,22 +655,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
 	if client == nil {
 		client = http.DefaultClient
 	}
-	return &streamableClientConn{
-		url:      t.url,
-		client:   client,
-		incoming: make(chan []byte, 100),
-		done:     make(chan struct{}),
-	}, nil
+	reconnOpts := t.opts.ReconnectOptions
+	if reconnOpts == nil {
+		reconnOpts = DefaultReconnectOptions
+	}
+	// Create a new cancellable context that will manage the connection's lifecycle.
+	// This is crucial for cleanly shutting down the background SSE listener by
+	// cancelling its blocking network operations, which prevents hangs on exit.
+	connCtx, cancel := context.WithCancel(context.Background())
+	conn := &streamableClientConn{
+		url:              t.url,
+		client:           client,
+		incoming:         make(chan []byte, 100),
+		done:             make(chan struct{}),
+		ReconnectOptions: reconnOpts,
+		ctx:              connCtx,
+		cancel:           cancel,
+	}
+	return conn, nil
 }

 type streamableClientConn struct {
-	url      string
-	client   *http.Client
-	incoming chan []byte
-	done     chan struct{}
+	url              string
+	client           *http.Client
+	incoming         chan []byte
+	done             chan struct{}
+	ReconnectOptions *StreamableReconnectOptions

 	closeOnce sync.Once
 	closeErr  error
+	ctx       context.Context
+	cancel    context.CancelFunc

 	mu              sync.Mutex
 	protocolVersion string
@@ -662,6 +707,12 @@ func (c *streamableClientConn) SessionID() string {

 // Read implements the [Connection] interface.
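+// If the background reconnect logic has already recorded a fatal error (for
+// example, after exhausting its retries), Read reports that error immediately
+// instead of blocking on the incoming channel.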
 func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
+	s.mu.Lock()
+	err := s.err
+	s.mu.Unlock()
+	if err != nil {
+		return nil, err
+	}
 	select {
 	case <-ctx.Done():
 		return nil, ctx.Err()
@@ -701,14 +752,26 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
 		return err
 	}

+	// The session has just been initialized.
 	if sessionID == "" { // locked
 		s._sessionID = gotSessionID
+		// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
+		// This can be used to open an SSE stream, allowing the server to
+		// communicate to the client, without the client first sending data via
+		// HTTP POST.
+		go s.establishSSE(&startSSEState{})
 	}
 	return nil
 }

+// startSSEState holds the state for initiating an SSE stream.
+type startSSEState struct {
+	lastEventID string
+	attempt     int
+}
+
 func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
 	data, err := jsonrpc2.EncodeMessage(msg)
 	if err != nil {
@@ -742,7 +805,8 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
 	sessionID = resp.Header.Get(sessionIDHeader)
 	switch ct := resp.Header.Get("Content-Type"); ct {
 	case "text/event-stream":
-		go s.handleSSE(resp)
+		// Section 2.1: The SSE stream is initiated after a POST.
+		go s.handleSSE(resp, &startSSEState{})
 	case "application/json":
 		// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
 		resp.Body.Close()
@@ -754,7 +818,9 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
 	return sessionID, nil
 }

-func (s *streamableClientConn) handleSSE(resp *http.Response) {
+// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
+// If the stream breaks, it uses the last received event ID to automatically trigger the reconnect logic.
+func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEState) {
 	defer resp.Body.Close()

 	done := make(chan struct{})
@@ -762,9 +828,10 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
 		defer close(done)
 		for evt, err := range scanEvents(resp.Body) {
 			if err != nil {
-				// TODO: surface this error; possibly break the stream
+				s.scheduleReconnect(opts)
 				return
 			}
+			opts.lastEventID = evt.id
 			select {
 			case <-s.done:
 				return
@@ -782,6 +849,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
 // Close implements the [Connection] interface.
 func (s *streamableClientConn) Close() error {
 	s.closeOnce.Do(func() {
+		// Cancel any hanging network requests.
+		s.cancel()
 		close(s.done)

 		req, err := http.NewRequest(http.MethodDelete, s.url, nil)
@@ -800,3 +869,86 @@ func (s *streamableClientConn) Close() error {
 	})
 	return s.closeErr
 }
+
+// establishSSE creates and manages the persistent SSE listening stream.
+// It is used for both the initial connection and all subsequent reconnect attempts,
+// using the Last-Event-ID header to resume a broken stream where it left off.
+// On a successful response, it delegates to handleSSE to process events;
+// on failure, it triggers the client's reconnect logic.
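+// The GET request carries the Mcp-Session-Id header once a session has been
+// established, and the Last-Event-ID header when resuming a broken stream.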
+func (s *streamableClientConn) establishSSE(opts *startSSEState) {
+	select {
+	case <-s.done:
+		return
+	default:
+	}
+
+	req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
+	if err != nil {
+		return
+	}
+	s.mu.Lock()
+	if s._sessionID != "" {
+		req.Header.Set("Mcp-Session-Id", s._sessionID)
+	}
+	s.mu.Unlock()
+	if opts.lastEventID != "" {
+		req.Header.Set("Last-Event-ID", opts.lastEventID)
+	}
+	req.Header.Set("Accept", "text/event-stream")
+
+	resp, err := s.client.Do(req)
+	if err != nil {
+		// On connection error, schedule a retry.
+		s.scheduleReconnect(opts)
+		return
+	}
+
+	// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
+	if resp.StatusCode == http.StatusMethodNotAllowed {
+		resp.Body.Close()
+		return
+	}
+
+	if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
+		resp.Body.Close()
+		return
+	}
+
+	s.handleSSE(resp, opts)
+}
+
+// scheduleReconnect schedules the next SSE reconnect attempt after a delay.
+func (s *streamableClientConn) scheduleReconnect(opts *startSSEState) {
+	reconnOpts := s.ReconnectOptions
+	if opts.attempt >= reconnOpts.MaxRetries {
+		s.mu.Lock()
+		s.err = fmt.Errorf("connection failed after %d attempts", reconnOpts.MaxRetries)
+		s.mu.Unlock()
+		s.Close() // Close the connection to unblock any readers.
+		return
+	}
+
+	delay := calculateReconnectDelay(reconnOpts, opts.attempt)
+
+	select {
+	case <-s.done:
+		return
+	case <-time.After(delay):
+		opts.attempt++
+		s.establishSSE(opts)
+	}
+}
+
+// calculateReconnectDelay calculates a delay using exponential backoff with added jitter.
+func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
+	// Calculate the exponential backoff using the grow factor.
+	backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt)))

+	// Cap the backoffDuration at maxDelay.
+	backoffDuration = min(backoffDuration, opts.maxDelay)
+
+	// Add random jitter of up to backoffDuration to avoid synchronized retries.
+	jitter := rand.N(backoffDuration)
+
+	return backoffDuration + jitter
+}
diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go
index e0e85a1..882dfa0 100644
--- a/mcp/streamable_test.go
+++ b/mcp/streamable_test.go
@@ -10,14 +10,17 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/http/cookiejar"
 	"net/http/httptest"
+	"net/http/httputil"
 	"net/url"
 	"strings"
 	"sync"
 	"sync/atomic"
 	"testing"
+	"time"

 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
@@ -105,6 +108,149 @@ func TestStreamableTransports(t *testing.T) {
 	}
 }

+// TestClientReplayAfterProxyBreak verifies that the client can recover from a
+// mid-stream network failure and receive replayed messages. It uses a proxy
+// that is killed and restarted to simulate a recoverable network outage.
+func TestClientReplayAfterProxyBreak(t *testing.T) {
+	// 1. Configure the real MCP server.
+	server := NewServer(testImpl, nil)
+
+	// Use a channel to synchronize the server's message sending with the test's
+	// proxy-killing action.
+	serverReadyToKillProxy := make(chan struct{})
+	var serverClosed sync.WaitGroup
+	AddTool(server, &Tool{Name: "multiMessageTool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) {
+		go func() {
+			bgCtx := context.Background()
+			// Send the first two messages immediately.
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg1"})
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg2"})
+
+			// Signal the test that it can now kill the proxy.
+			serverClosed.Add(1)
+			close(serverReadyToKillProxy)
+			// Wait for the test to kill the proxy before sending the rest.
+			serverClosed.Wait()
+
+			// These messages should be queued for replay by the server after
+			// the client's connection drops.
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg3"})
+			_ = ss.NotifyProgress(bgCtx, &ProgressNotificationParams{Message: "msg4"})
+		}()
+		return &CallToolResultFor[any]{}, nil
+	})
+	realServer := httptest.NewServer(NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, nil))
+	defer realServer.Close()
+	realServerURL, err := url.Parse(realServer.URL)
+	if err != nil {
+		t.Fatalf("Failed to parse real server URL: %v", err)
+	}
+
+	// 2. Configure a proxy that sits between the client and the real server.
+	proxyHandler := httputil.NewSingleHostReverseProxy(realServerURL)
+	proxy := httptest.NewServer(proxyHandler)
+	proxyAddr := proxy.Listener.Addr().String() // Get the address to restart it later.
+
+	// 3. Configure the client to connect to the proxy with default options.
+	clientTransport := NewStreamableClientTransport(proxy.URL, nil)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	// 4. Connect, perform handshake, and trigger the tool.
+	conn, err := clientTransport.Connect(ctx)
+	if err != nil {
+		t.Fatalf("Connect() failed: %v", err)
+	}
+
+	// Perform handshake.
+	initReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(100), Method: "initialize", Params: mustMarshal(t, &InitializeParams{})}
+	if err := conn.Write(ctx, initReq); err != nil {
+		t.Fatalf("Write(initialize) failed: %v", err)
+	}
+	if _, err := conn.Read(ctx); err != nil {
+		t.Fatalf("Read(initialize resp) failed: %v", err)
+	}
+	if err := conn.Write(ctx, &jsonrpc.Request{Method: "initialized", Params: mustMarshal(t, &InitializedParams{})}); err != nil {
+		t.Fatalf("Write(initialized) failed: %v", err)
+	}
+
+	callReq := &jsonrpc.Request{ID: jsonrpc2.Int64ID(1), Method: "tools/call", Params: mustMarshal(t, &CallToolParams{Name: "multiMessageTool"})}
+	if err := conn.Write(ctx, callReq); err != nil {
+		t.Fatalf("Write(tool/call) failed: %v", err)
+	}
+
+	// 5. Read and verify the first two notifications, then wait for the server's
+	// signal that it is safe to kill the proxy.
+	receivedNotifications := readProgressNotifications(t, ctx, conn, 2)
+
+	wantReceived := []jsonrpc.Message{
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg1"})},
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg2"})},
+	}
+	transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() })
+
+	if diff := cmp.Diff(wantReceived, receivedNotifications, transform); diff != "" {
+		t.Errorf("Received notifications mismatch (-want +got):\n%s", diff)
+	}
+
+	select {
+	case <-serverReadyToKillProxy:
+		// Server has sent the first two messages and is paused.
+	case <-ctx.Done():
+		t.Fatalf("Context timed out before server was ready to kill proxy")
+	}
+
+	// 6. Simulate a total network failure by closing the proxy.
+	t.Log("--- Killing proxy to simulate network failure ---")
+	proxy.CloseClientConnections()
+	proxy.Close()
+	serverClosed.Done()
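+	// While the proxy is down, the client's SSE stream is broken and its
+	// background reconnect logic keeps retrying with exponential backoff.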
+
+	// 7. Simulate network recovery by restarting the proxy on the same address.
+	t.Logf("--- Restarting proxy on %s ---", proxyAddr)
+	listener, err := net.Listen("tcp", proxyAddr)
+	if err != nil {
+		t.Fatalf("Failed to listen on proxy address: %v", err)
+	}
+	restartedProxy := &http.Server{Handler: proxyHandler}
+	go restartedProxy.Serve(listener)
+	defer restartedProxy.Close()
+
+	// 8. Continue reading from the same connection object.
+	// Its internal logic should successfully retry, reconnect to the new proxy,
+	// and receive the replayed messages.
+	recoveredNotifications := readProgressNotifications(t, ctx, conn, 2)
+
+	// 9. Verify the correct messages were received on the recovered connection.
+	wantRecovered := []jsonrpc.Message{
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg3"})},
+		&jsonrpc.Request{Method: "notifications/progress", Params: mustMarshal(t, &ProgressNotificationParams{Message: "msg4"})},
+	}
+
+	if diff := cmp.Diff(wantRecovered, recoveredNotifications, transform); diff != "" {
+		t.Errorf("Recovered notifications mismatch (-want +got):\n%s", diff)
+	}
+}
+
+// readProgressNotifications reads from conn until count progress notifications
+// have been received, failing the test if the context expires first.
+func readProgressNotifications(t *testing.T, ctx context.Context, conn Connection, count int) []jsonrpc.Message {
+	t.Helper()
+	var notifications []jsonrpc.Message
+	for len(notifications) < count && ctx.Err() == nil {
+		msg, err := conn.Read(ctx)
+		if err != nil {
+			t.Fatalf("Failed to read notification: %v", err)
+		}
+		if req, ok := msg.(*jsonrpc.Request); ok && req.Method == "notifications/progress" {
+			notifications = append(notifications, req)
+		}
+	}
+	if len(notifications) != count {
+		t.Fatalf("Expected to read %d notifications, but got %d", count, len(notifications))
+	}
+	return notifications
+}
+
 func TestStreamableServerTransport(t *testing.T) {
 	// This test checks detailed behavior of the streamable server transport, by
 	// faking the behavior of a streamable client using a sequence of HTTP