Skip to content

Enforce String or Integer non-null ID for MCP JSON-RPC Messages with a new McpSchema subclass #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
import io.modelcontextprotocol.spec.McpSchema.MessageId;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -161,8 +162,8 @@ void testBuilderPattern() {
@Test
void testMessageProcessing() {
// Create a test message
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
Map.of("key", "value"));
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method",
MessageId.of("test-id"), Map.of("key", "value"));

// Simulate receiving the message
transport.simulateMessageEvent("""
Expand Down Expand Up @@ -192,8 +193,8 @@ void testResponseMessageProcessing() {
""");

// Create and send a request message
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
Map.of("key", "value"));
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method",
MessageId.of("test-id"), Map.of("key", "value"));

// Verify message handling
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
Expand All @@ -216,8 +217,8 @@ void testErrorMessageProcessing() {
""");

// Create and send a request message
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
Map.of("key", "value"));
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method",
MessageId.of("test-id"), Map.of("key", "value"));

// Verify message handling
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
Expand Down Expand Up @@ -246,8 +247,8 @@ void testGracefulShutdown() {
StepVerifier.create(transport.closeGracefully()).verifyComplete();

// Create a test message
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
Map.of("key", "value"));
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method",
MessageId.of("test-id"), Map.of("key", "value"));

// Verify message is not processed after shutdown
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
Expand Down Expand Up @@ -292,10 +293,10 @@ void testMultipleMessageProcessing() {
""");

// Create and send corresponding messages
JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1",
JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", MessageId.of("id1"),
Map.of("key", "value1"));

JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2",
JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", MessageId.of("id2"),
Map.of("key", "value2"));

// Verify both messages are processed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.modelcontextprotocol.spec;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.spec.McpSchema.MessageId;
import io.modelcontextprotocol.util.Assert;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
Expand Down Expand Up @@ -47,7 +48,7 @@ public class McpClientSession implements McpSession {
private final McpClientTransport transport;

/** Map of pending responses keyed by request ID */
private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
private final ConcurrentHashMap<MessageId, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();

/** Map of request handlers keyed by method name */
private final ConcurrentHashMap<String, RequestHandler<?>> requestHandlers = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -231,10 +232,10 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
/**
* Generates a unique request ID in a non-blocking way. Combines a session-specific
* prefix with an atomic counter to ensure uniqueness.
* @return A unique request ID string
* @return A unique request ID
*/
private String generateRequestId() {
return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement();
private MessageId generateRequestId() {
return MessageId.of(this.sessionPrefix + "-" + this.requestCounter.getAndIncrement());
}

/**
Expand All @@ -247,7 +248,7 @@ private String generateRequestId() {
*/
@Override
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
String requestId = this.generateRequestId();
MessageId requestId = this.generateRequestId();

return Mono.deferContextual(ctx -> Mono.<McpSchema.JSONRPCResponse>create(pendingResponseSink -> {
logger.debug("Sending message for method {}", method);
Expand Down
117 changes: 115 additions & 2 deletions mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -21,8 +22,18 @@
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.JsonTypeInfo.As;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;

import io.modelcontextprotocol.util.Assert;

Expand All @@ -35,6 +46,7 @@
* @author Christian Tzolov
* @author Luca Chang
* @author Surbhi Bansal
* @author Zachary German
*/
public final class McpSchema {

Expand Down Expand Up @@ -144,6 +156,107 @@ public static final class ErrorCodes {

}

/**
* MCP JSON-RPC Message ID wrapper: MUST be non-null String or Int.
* <p>
* <b>Note</b>: This does <b>not</b> follow the JSON-RPC 'id' specification, which is
* nullable and could be a floating-point.
* </p>
*/
@JsonSerialize(using = MessageId.Serializer.class)
@JsonDeserialize(using = MessageId.Deserializer.class)
public static final class MessageId {

private final Object value;

public MessageId(String value) {
this.value = Objects.requireNonNull(value, "'id' must not be null");
}

public MessageId(Integer value) {
this.value = Objects.requireNonNull(value, "'id' must not be null");
}

public static MessageId of(Object raw) {
if (raw instanceof String s)
return new MessageId(s);
if (raw instanceof Integer i)
return new MessageId(i);
throw new IllegalArgumentException("MCP 'id' must be String or Integer");
}

public boolean isString() {
return value instanceof String;
}

public boolean isInteger() {
return value instanceof Integer;
}

public String asString() {
return (String) value;
}

public Integer asInteger() {
return (Integer) value;
}

public Object raw() {
return value;
}

@Override
public String toString() {
return String.valueOf(value);
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
MessageId messageId = (MessageId) o;
return value.equals(messageId.value);
}

@Override
public int hashCode() {
return value.hashCode();
}

public static class Deserializer extends JsonDeserializer<MessageId> {

@Override
public MessageId deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
JsonToken t = p.getCurrentToken();
if (t == JsonToken.VALUE_STRING) {
return new MessageId(p.getText());
}
else if (t == JsonToken.VALUE_NUMBER_INT) {
return new MessageId(p.getIntValue());
}
throw JsonMappingException.from(p, "MCP 'id' must be a non-null String or Integer");
}

}

public static class Serializer extends JsonSerializer<MessageId> {

@Override
public void serialize(MessageId id, JsonGenerator gen, SerializerProvider serializers) throws IOException {
if (id.isString()) {
gen.writeString(id.asString());
}
else {
gen.writeNumber(id.asInteger());
}
}

}

}

public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest,
CompleteRequest, GetPromptRequest, PaginatedRequest, ReadResourceRequest {

Expand Down Expand Up @@ -216,7 +329,7 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati
public record JSONRPCRequest( // @formatter:off
@JsonProperty("jsonrpc") String jsonrpc,
@JsonProperty("method") String method,
@JsonProperty("id") Object id,
@JsonProperty("id") MessageId id,
@JsonProperty("params") Object params) implements JSONRPCMessage { // @formatter:on
}

Expand Down Expand Up @@ -251,7 +364,7 @@ public record JSONRPCNotification( // @formatter:off
// @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY)
public record JSONRPCResponse( // @formatter:off
@JsonProperty("jsonrpc") String jsonrpc,
@JsonProperty("id") Object id,
@JsonProperty("id") MessageId id,
@JsonProperty("result") Object result,
@JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { // @formatter:on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.spec.McpSchema.MessageId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand All @@ -23,7 +24,7 @@ public class McpServerSession implements McpSession {

private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class);

private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
private final ConcurrentHashMap<MessageId, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();

private final String id;

Expand Down Expand Up @@ -104,13 +105,13 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl
this.clientInfo.lazySet(clientInfo);
}

private String generateRequestId() {
return this.id + "-" + this.requestCounter.getAndIncrement();
private MessageId generateRequestId() {
return MessageId.of(this.id + "-" + this.requestCounter.getAndIncrement());
}

@Override
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
String requestId = this.generateRequestId();
MessageId requestId = this.generateRequestId();

return Mono.<McpSchema.JSONRPCResponse>create(sink -> {
this.pendingResponses.put(requestId, sink);
Expand Down
Loading