Skip to content

mcp/streamable: add resumability for the Streamable transport #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 166 additions & 14 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@ import (
"context"
"fmt"
"io"
"math"
"math/rand/v2"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
Expand Down Expand Up @@ -594,12 +597,39 @@ type StreamableClientTransport struct {
opts StreamableClientTransportOptions
}

// StreamableReconnectOptions defines parameters for client reconnect attempts.
type StreamableReconnectOptions struct {
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
// A value of 0 or less means never retry.
MaxRetries int

// growFactor is the multiplicative factor by which the delay increases after each attempt.
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
// It must be 1.0 or greater if MaxRetries is greater than 0.
growFactor float64

// initialDelay is the base delay for the first reconnect attempt.
initialDelay time.Duration

// maxDelay caps the backoff delay, preventing it from growing indefinitely.
maxDelay time.Duration
}

// DefaultReconnectOptions provides sensible defaults for reconnect logic.
var DefaultReconnectOptions = &StreamableReconnectOptions{
MaxRetries: 5,
growFactor: 1.5,
initialDelay: 1 * time.Second,
maxDelay: 30 * time.Second,
}

// StreamableClientTransportOptions provides options for the
// [NewStreamableClientTransport] constructor.
type StreamableClientTransportOptions struct {
// HTTPClient is the client to use for making HTTP requests. If nil,
// http.DefaultClient is used.
HTTPClient *http.Client
HTTPClient *http.Client
ReconnectOptions *StreamableReconnectOptions
}

// NewStreamableClientTransport returns a new client transport that connects to
Expand All @@ -625,22 +655,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
if client == nil {
client = http.DefaultClient
}
return &streamableClientConn{
url: t.url,
client: client,
incoming: make(chan []byte, 100),
done: make(chan struct{}),
}, nil
reconnOpts := t.opts.ReconnectOptions
if reconnOpts == nil {
reconnOpts = DefaultReconnectOptions
}
// Create a new cancellable context that will manage the connection's lifecycle.
// This is crucial for cleanly shutting down the background SSE listener by
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(context.Background())
conn := &streamableClientConn{
url: t.url,
client: client,
incoming: make(chan []byte, 100),
done: make(chan struct{}),
ReconnectOptions: reconnOpts,
ctx: connCtx,
cancel: cancel,
}
return conn, nil
}

type streamableClientConn struct {
url string
client *http.Client
incoming chan []byte
done chan struct{}
url string
client *http.Client
incoming chan []byte
done chan struct{}
ReconnectOptions *StreamableReconnectOptions

closeOnce sync.Once
closeErr error
ctx context.Context
cancel context.CancelFunc

mu sync.Mutex
protocolVersion string
Expand All @@ -662,6 +707,12 @@ func (c *streamableClientConn) SessionID() string {

// Read implements the [Connection] interface.
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) {
s.mu.Lock()
err := s.err
s.mu.Unlock()
if err != nil {
return nil, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
Expand Down Expand Up @@ -701,14 +752,26 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
return err
}

// The session has just been initialized.
if sessionID == "" {
// locked
s._sessionID = gotSessionID
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint.
// This can be used to open an SSE stream, allowing the server to
// communicate to the client, without the client first sending data via
// HTTP POST.
go s.establishSSE(&startSSEState{})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This goroutine is not present in the current code. Why is it needed here?

More general question: is there a description of the design of resumability somewhere? Not looking for a design doc, just a short sketch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, the previous logic only covered section 2.1

Which established an SSE stream if the server responds to a client POST with a text/event-stream.

However in section 2.2, it says: The client MAY issue an HTTP GET to the MCP endpoint. This can be used to open an SSE stream, allowing the server to communicate to the client, without the client first sending data via HTTP POST

So from my understanding, we need to proactively issue a client GET request after initialization to see if the server has any SSE streams waiting before a client POST.

Added a comment.

}

return nil
}

// startSSEState holds the state for initiating an SSE stream.
type startSSEState struct {
lastEventID string
attempt int
}

func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) {
data, err := jsonrpc2.EncodeMessage(msg)
if err != nil {
Expand Down Expand Up @@ -742,7 +805,8 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
sessionID = resp.Header.Get(sessionIDHeader)
switch ct := resp.Header.Get("Content-Type"); ct {
case "text/event-stream":
go s.handleSSE(resp)
// Section 2.1: The SSE stream is initiated after a POST.
go s.handleSSE(resp, &startSSEState{})
case "application/json":
// TODO: read the body and send to s.incoming (in a select that also receives from s.done).
resp.Body.Close()
Expand All @@ -754,17 +818,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string
return sessionID, nil
}

func (s *streamableClientConn) handleSSE(resp *http.Response) {
// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel.
// If the stream breaks, it uses the last received event ID to automatically trigger the reconnect logic.
func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEState) {
defer resp.Body.Close()

done := make(chan struct{})
go func() {
defer close(done)
for evt, err := range scanEvents(resp.Body) {
if err != nil {
// TODO: surface this error; possibly break the stream
s.scheduleReconnect(opts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not every error is going to be worth reconnecting for, is it? There could be bugs or bad data. We'll probably need a way to distinguish these. For a later PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll follow-up with another CL after this one.

return
}
opts.lastEventID = evt.id
select {
case <-s.done:
return
Expand All @@ -782,6 +849,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) {
// Close implements the [Connection] interface.
func (s *streamableClientConn) Close() error {
s.closeOnce.Do(func() {
// Cancel any hanging network requests.
s.cancel()
close(s.done)

req, err := http.NewRequest(http.MethodDelete, s.url, nil)
Expand All @@ -800,3 +869,86 @@ func (s *streamableClientConn) Close() error {
})
return s.closeErr
}

// establishSSE creates and manages the persistent SSE listening stream.
// It is used for both the initial connection and all subsequent reconnect attempts,
// using the Last-Event-ID header to resume a broken stream where it left off.
// On a successful response, it delegates to handleSSE to process events;
// on failure, it triggers the client's reconnect logic.
func (s *streamableClientConn) establishSSE(opts *startSSEState) {
select {
case <-s.done:
return
default:
}

req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil)
if err != nil {
return
}
s.mu.Lock()
if s._sessionID != "" {
req.Header.Set("Mcp-Session-Id", s._sessionID)
}
s.mu.Unlock()
if opts.lastEventID != "" {
req.Header.Set("Last-Event-ID", opts.lastEventID)
}
req.Header.Set("Accept", "text/event-stream")

resp, err := s.client.Do(req)
if err != nil {
// On connection error, schedule a retry.
s.scheduleReconnect(opts)
return
}

// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
if resp.StatusCode == http.StatusMethodNotAllowed {
resp.Body.Close()
return
}

if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
resp.Body.Close()
return
}

s.handleSSE(resp, opts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You started a goroutine with this in postMessage. Why are you calling it again here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for the client GET initiated stream.

I think there are 2 flows to start an SSE stream:

  1. client GET request (here)
  2. client POST request (postMessage)

}

// scheduleReconnect schedules the next SSE reconnect attempt after a delay.
func (s *streamableClientConn) scheduleReconnect(opts *startSSEState) {
reconnOpts := s.ReconnectOptions
if opts.attempt >= reconnOpts.MaxRetries {
s.mu.Lock()
s.err = fmt.Errorf("connection failed after %d attempts", reconnOpts.MaxRetries)
s.mu.Unlock()
s.Close() // Close the connection to unblock any readers.
return
}

delay := calculateReconnectDelay(reconnOpts, opts.attempt)

select {
case <-s.done:
return
case <-time.After(delay):
opts.attempt++
s.establishSSE(opts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a recursive call to establishSSE, so in theory you could blow the stack.
Not sure what the right fix is.

}
}

// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
// Calculate the exponential backoff using the grow factor.
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt)))

// Cap the backoffDuration at maxDelay.
backoffDuration = min(backoffDuration, opts.maxDelay)

// Use a full jitter using backoffDuration
jitter := rand.N(backoffDuration)

return backoffDuration + jitter
}
Loading
Loading