-
Notifications
You must be signed in to change notification settings - Fork 75
mcp/streamable: add resumability for the Streamable transport #133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,11 +9,14 @@ import ( | |
"context" | ||
"fmt" | ||
"io" | ||
"math" | ||
"math/rand/v2" | ||
"net/http" | ||
"strconv" | ||
"strings" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" | ||
"github.com/modelcontextprotocol/go-sdk/jsonrpc" | ||
|
@@ -594,12 +597,39 @@ type StreamableClientTransport struct { | |
opts StreamableClientTransportOptions | ||
} | ||
|
||
// StreamableReconnectOptions defines parameters for client reconnect attempts. | ||
// | ||
// Only MaxRetries is exported. The backoff fields are unexported, so a value | ||
// constructed outside this package leaves them at their zero values. | ||
type StreamableReconnectOptions struct { | ||
// MaxRetries is the maximum number of times to attempt a reconnect before giving up. | ||
// A value of 0 or less means never retry. | ||
MaxRetries int | ||
|
||
// growFactor is the multiplicative factor by which the delay increases after each attempt. | ||
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time. | ||
// It must be 1.0 or greater if MaxRetries is greater than 0. | ||
// calculateReconnectDelay raises it to the power of the attempt number. | ||
growFactor float64 | ||
|
||
// initialDelay is the base delay for the first reconnect attempt. | ||
initialDelay time.Duration | ||
|
||
// maxDelay caps the backoff delay, preventing it from growing indefinitely. | ||
// The cap is applied to the backoff term before jitter is added, so the | ||
// delay actually waited can be larger than maxDelay. | ||
maxDelay time.Duration | ||
} | ||
|
||
// DefaultReconnectOptions provides sensible defaults for reconnect logic. | ||
// Connect falls back to it when StreamableClientTransportOptions.ReconnectOptions | ||
// is nil. It is a package-level pointer, so the same instance is stored on every | ||
// connection that relies on the defaults. | ||
var DefaultReconnectOptions = &StreamableReconnectOptions{ | ||
// Give up after five reconnect attempts. | ||
MaxRetries: 5, | ||
// Each attempt waits 1.5 times longer than the previous one, before jitter. | ||
growFactor: 1.5, | ||
initialDelay: 1 * time.Second, | ||
maxDelay: 30 * time.Second, | ||
} | ||
|
||
// StreamableClientTransportOptions provides options for the | ||
// [NewStreamableClientTransport] constructor. | ||
type StreamableClientTransportOptions struct { | ||
// HTTPClient is the client to use for making HTTP requests. If nil, | ||
// http.DefaultClient is used. | ||
// NOTE(review): the field line below appears twice, presumably the removed and | ||
// the re-aligned line of the diff shown together; confirm against the file. | ||
HTTPClient *http.Client | ||
HTTPClient *http.Client | ||
// ReconnectOptions configures how the client reconnects a broken SSE stream. | ||
// If nil, DefaultReconnectOptions is used. | ||
ReconnectOptions *StreamableReconnectOptions | ||
} | ||
|
||
// NewStreamableClientTransport returns a new client transport that connects to | ||
|
@@ -625,22 +655,37 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er | |
if client == nil { | ||
client = http.DefaultClient | ||
} | ||
return &streamableClientConn{ | ||
url: t.url, | ||
client: client, | ||
incoming: make(chan []byte, 100), | ||
done: make(chan struct{}), | ||
}, nil | ||
reconnOpts := t.opts.ReconnectOptions | ||
if reconnOpts == nil { | ||
reconnOpts = DefaultReconnectOptions | ||
} | ||
// Create a new cancellable context that will manage the connection's lifecycle. | ||
// This is crucial for cleanly shutting down the background SSE listener by | ||
// cancelling its blocking network operations, which prevents hangs on exit. | ||
connCtx, cancel := context.WithCancel(context.Background()) | ||
conn := &streamableClientConn{ | ||
url: t.url, | ||
client: client, | ||
incoming: make(chan []byte, 100), | ||
done: make(chan struct{}), | ||
ReconnectOptions: reconnOpts, | ||
ctx: connCtx, | ||
cancel: cancel, | ||
} | ||
return conn, nil | ||
} | ||
|
||
type streamableClientConn struct { | ||
url string | ||
client *http.Client | ||
incoming chan []byte | ||
done chan struct{} | ||
url string | ||
client *http.Client | ||
incoming chan []byte | ||
done chan struct{} | ||
ReconnectOptions *StreamableReconnectOptions | ||
|
||
closeOnce sync.Once | ||
closeErr error | ||
ctx context.Context | ||
cancel context.CancelFunc | ||
|
||
mu sync.Mutex | ||
protocolVersion string | ||
|
@@ -662,6 +707,12 @@ func (c *streamableClientConn) SessionID() string { | |
|
||
// Read implements the [Connection] interface. | ||
func (s *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error) { | ||
s.mu.Lock() | ||
err := s.err | ||
s.mu.Unlock() | ||
if err != nil { | ||
return nil, err | ||
} | ||
select { | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
|
@@ -701,14 +752,26 @@ func (s *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e | |
return err | ||
} | ||
|
||
// The session has just been initialized. | ||
if sessionID == "" { | ||
// locked | ||
s._sessionID = gotSessionID | ||
// Section 2.2: The client MAY issue an HTTP GET to the MCP endpoint. | ||
// This can be used to open an SSE stream, allowing the server to | ||
// communicate to the client, without the client first sending data via | ||
// HTTP POST. | ||
go s.establishSSE(&startSSEState{}) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// startSSEState holds the state for initiating an SSE stream. | ||
// One value is passed through establishSSE, handleSSE and scheduleReconnect so | ||
// that a reconnect can resume from the last event that was seen. | ||
type startSSEState struct { | ||
// lastEventID is the ID of the most recent event read by handleSSE. When it is | ||
// non-empty, establishSSE sends it as the Last-Event-ID request header. | ||
lastEventID string | ||
// attempt counts the reconnect attempts made so far. scheduleReconnect | ||
// increments it and gives up once it reaches ReconnectOptions.MaxRetries. | ||
attempt int | ||
} | ||
|
||
func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string, msg jsonrpc.Message) (string, error) { | ||
data, err := jsonrpc2.EncodeMessage(msg) | ||
if err != nil { | ||
|
@@ -742,7 +805,8 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string | |
sessionID = resp.Header.Get(sessionIDHeader) | ||
switch ct := resp.Header.Get("Content-Type"); ct { | ||
case "text/event-stream": | ||
go s.handleSSE(resp) | ||
// Section 2.1: The SSE stream is initiated after a POST. | ||
go s.handleSSE(resp, &startSSEState{}) | ||
case "application/json": | ||
// TODO: read the body and send to s.incoming (in a select that also receives from s.done). | ||
resp.Body.Close() | ||
|
@@ -754,17 +818,20 @@ func (s *streamableClientConn) postMessage(ctx context.Context, sessionID string | |
return sessionID, nil | ||
} | ||
|
||
func (s *streamableClientConn) handleSSE(resp *http.Response) { | ||
// handleSSE processes an incoming Server-Sent Events stream, pushing received messages to the client's channel. | ||
// If the stream breaks, it uses the last received event ID to automatically trigger the reconnect logic. | ||
func (s *streamableClientConn) handleSSE(resp *http.Response, opts *startSSEState) { | ||
defer resp.Body.Close() | ||
|
||
done := make(chan struct{}) | ||
go func() { | ||
defer close(done) | ||
for evt, err := range scanEvents(resp.Body) { | ||
if err != nil { | ||
// TODO: surface this error; possibly break the stream | ||
s.scheduleReconnect(opts) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not every error is going to be worth reconnecting for, is it? There could be bugs or bad data. We'll probably need a way to distinguish these. For a later PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, I'll follow up with another CL after this one. |
||
return | ||
} | ||
opts.lastEventID = evt.id | ||
select { | ||
case <-s.done: | ||
return | ||
|
@@ -782,6 +849,8 @@ func (s *streamableClientConn) handleSSE(resp *http.Response) { | |
// Close implements the [Connection] interface. | ||
func (s *streamableClientConn) Close() error { | ||
s.closeOnce.Do(func() { | ||
// Cancel any hanging network requests. | ||
s.cancel() | ||
close(s.done) | ||
|
||
req, err := http.NewRequest(http.MethodDelete, s.url, nil) | ||
|
@@ -800,3 +869,86 @@ func (s *streamableClientConn) Close() error { | |
}) | ||
return s.closeErr | ||
} | ||
|
||
// establishSSE creates and manages the persistent SSE listening stream. | ||
// It is used for both the initial connection and all subsequent reconnect attempts, | ||
// using the Last-Event-ID header to resume a broken stream where it left off. | ||
// On a successful response, it delegates to handleSSE to process events; | ||
// on failure, it triggers the client's reconnect logic. | ||
func (s *streamableClientConn) establishSSE(opts *startSSEState) { | ||
select { | ||
case <-s.done: | ||
return | ||
default: | ||
} | ||
|
||
req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.url, nil) | ||
if err != nil { | ||
return | ||
} | ||
s.mu.Lock() | ||
if s._sessionID != "" { | ||
req.Header.Set("Mcp-Session-Id", s._sessionID) | ||
} | ||
s.mu.Unlock() | ||
if opts.lastEventID != "" { | ||
req.Header.Set("Last-Event-ID", opts.lastEventID) | ||
} | ||
req.Header.Set("Accept", "text/event-stream") | ||
|
||
resp, err := s.client.Do(req) | ||
if err != nil { | ||
// On connection error, schedule a retry. | ||
s.scheduleReconnect(opts) | ||
return | ||
} | ||
|
||
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint. | ||
if resp.StatusCode == http.StatusMethodNotAllowed { | ||
resp.Body.Close() | ||
return | ||
} | ||
|
||
if !strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { | ||
resp.Body.Close() | ||
return | ||
} | ||
|
||
s.handleSSE(resp, opts) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You started a goroutine with this in postMessage. Why are you calling it again here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is for the client GET initiated stream. I think there are 2 flows to start an SSE stream:
|
||
} | ||
|
||
// scheduleReconnect schedules the next SSE reconnect attempt after a delay. | ||
func (s *streamableClientConn) scheduleReconnect(opts *startSSEState) { | ||
reconnOpts := s.ReconnectOptions | ||
if opts.attempt >= reconnOpts.MaxRetries { | ||
s.mu.Lock() | ||
s.err = fmt.Errorf("connection failed after %d attempts", reconnOpts.MaxRetries) | ||
s.mu.Unlock() | ||
s.Close() // Close the connection to unblock any readers. | ||
return | ||
samthanawalla marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
delay := calculateReconnectDelay(reconnOpts, opts.attempt) | ||
|
||
select { | ||
case <-s.done: | ||
return | ||
case <-time.After(delay): | ||
opts.attempt++ | ||
s.establishSSE(opts) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a recursive call to establishSSE, so in theory you could blow the stack. |
||
} | ||
} | ||
|
||
// calculateReconnectDelay calculates a delay using exponential backoff with full jitter. | ||
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration { | ||
// Calculate the exponential backoff using the grow factor. | ||
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt))) | ||
|
||
// Cap the backoffDuration at maxDelay. | ||
backoffDuration = min(backoffDuration, opts.maxDelay) | ||
|
||
// Use a full jitter using backoffDuration | ||
jitter := rand.N(backoffDuration) | ||
|
||
return backoffDuration + jitter | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This goroutine is not present in the current code. Why is it needed here?
More general question: is there a description of the design of resumability somewhere? Not looking for a design doc, just a short sketch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, the previous logic only covered section 2.1,
which establishes an SSE stream when the server responds to a client POST with Content-Type text/event-stream.
However in section 2.2, it says: The client MAY issue an HTTP GET to the MCP endpoint. This can be used to open an SSE stream, allowing the server to communicate to the client, without the client first sending data via HTTP POST
So from my understanding, we need to proactively issue a client GET request after initialization to see if the server has any SSE streams waiting before a client POST.
Added a comment.