Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions kotlin-sdk-client/api/kotlin-sdk-client.api
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp
protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V
public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V
public final fun callTool (Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun callTool (Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun callTool (Ljava/lang/String;Ljava/util/Map;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public final fun complete (Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun complete$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ import kotlinx.atomicfu.update
import kotlinx.collections.immutable.minus
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.collections.immutable.toPersistentSet
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.add
import kotlinx.serialization.json.buildJsonArray
import kotlinx.serialization.json.buildJsonObject
import kotlin.coroutines.cancellation.CancellationException

private val logger = KotlinLogging.logger {}
Expand Down Expand Up @@ -185,20 +189,14 @@ public open class Client(private val clientInfo: Implementation, options: Client
}
}

Method.Defined.PromptsGet,
Method.Defined.PromptsList,
Method.Defined.CompletionComplete,
-> {
Method.Defined.PromptsGet, Method.Defined.PromptsList, Method.Defined.CompletionComplete -> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually better to have each case on a new line, because add/delete is more visible in diff

if (serverCapabilities?.prompts == null) {
throw IllegalStateException("Server does not support prompts (required for $method)")
}
}

Method.Defined.ResourcesList,
Method.Defined.ResourcesTemplatesList,
Method.Defined.ResourcesRead,
Method.Defined.ResourcesSubscribe,
Method.Defined.ResourcesUnsubscribe,
Method.Defined.ResourcesList, Method.Defined.ResourcesTemplatesList,
Method.Defined.ResourcesRead, Method.Defined.ResourcesSubscribe, Method.Defined.ResourcesUnsubscribe,
-> {
val resCaps = serverCapabilities?.resources
?: error("Server does not support resources (required for $method)")
Expand All @@ -210,17 +208,13 @@ public open class Client(private val clientInfo: Implementation, options: Client
}
}

Method.Defined.ToolsCall,
Method.Defined.ToolsList,
-> {
Method.Defined.ToolsCall, Method.Defined.ToolsList -> {
if (serverCapabilities?.tools == null) {
throw IllegalStateException("Server does not support tools (required for $method)")
}
}

Method.Defined.Initialize,
Method.Defined.Ping,
-> {
Method.Defined.Initialize, Method.Defined.Ping -> {
// No specific capability required
}

Expand Down Expand Up @@ -405,10 +399,14 @@ public open class Client(private val clientInfo: Implementation, options: Client
): EmptyRequestResult = request(request, options)

/**
* Calls a tool on the server by name, passing the specified arguments.
* Calls a tool on the server by name, passing the specified arguments and metadata.
*
* @param name The name of the tool to call.
* @param arguments A map of argument names to values for the tool.
* @param meta A map of metadata key-value pairs. Keys must follow MCP specification format.
* - Optional prefix: dot-separated labels followed by slash (e.g., "api.example.com/")
* - Name: alphanumeric start/end, may contain hyphens, underscores, dots, alphanumerics
* - Reserved prefixes starting with "mcp" or "modelcontextprotocol" are forbidden
* @param compatibility Whether to use compatibility mode for older protocol versions.
* @param options Optional request options.
* @return The result of the tool call, or `null` if none.
Expand All @@ -417,23 +415,19 @@ public open class Client(private val clientInfo: Implementation, options: Client
public suspend fun callTool(
name: String,
arguments: Map<String, Any?>,
meta: Map<String, Any?> = emptyMap(),
compatibility: Boolean = false,
options: RequestOptions? = null,
): CallToolResultBase? {
val jsonArguments = arguments.mapValues { (_, value) ->
when (value) {
is String -> JsonPrimitive(value)
is Number -> JsonPrimitive(value)
is Boolean -> JsonPrimitive(value)
is JsonElement -> value
null -> JsonNull
else -> JsonPrimitive(value.toString())
}
}
validateMetaKeys(meta.keys)

val jsonArguments = convertToJsonMap(arguments)
val jsonMeta = convertToJsonMap(meta)

val request = CallToolRequest(
name = name,
arguments = JsonObject(jsonArguments),
_meta = JsonObject(jsonMeta),
)
return callTool(request, compatibility, options)
}
Expand Down Expand Up @@ -588,4 +582,116 @@ public open class Client(private val clientInfo: Implementation, options: Client
val rootList = roots.value.values.toList()
return ListRootsResult(rootList)
}

/**
* Validates meta keys according to MCP specification.
*
* Key format: [prefix/]name
* - Prefix (optional): dot-separated labels + slash
* - Reserved prefixes contain "modelcontextprotocol" or "mcp" as complete labels
* - Name: alphanumeric start/end, may contain hyphens, underscores, dots (empty allowed)
*/
private fun validateMetaKeys(keys: Set<String>) {
val labelPattern = Regex("[a-zA-Z]([a-zA-Z0-9-]*[a-zA-Z0-9])?")
val namePattern = Regex("[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?")
Comment on lines +601 to +602
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regex patterns should be static constants


keys.forEach { key ->
require(key.isNotEmpty()) { "Meta key cannot be empty" }

val (prefix, name) = key.split('/', limit = 2).let { parts ->
when (parts.size) {
1 -> null to parts[0]
2 -> parts[0] to parts[1]
else -> throw IllegalArgumentException("Unexpected split result for key: $key")
}
}

// Validate prefix if present
prefix?.let {
require(it.isNotEmpty()) { "Invalid _meta key '$key': prefix cannot be empty" }

val labels = it.split('.')
require(labels.all { label -> label.matches(labelPattern) }) {
"Invalid _meta key '$key': prefix labels must start with a letter, end with letter/digit, and contain only letters, digits, or hyphens"
}

require(
labels.none { label ->
label.equals("modelcontextprotocol", ignoreCase = true) ||
label.equals("mcp", ignoreCase = true)
},
) {
"Invalid _meta key '$key': prefix cannot contain reserved labels 'modelcontextprotocol' or 'mcp'"
}
}

// Validate name (empty allowed)
require(name.isEmpty() || name.matches(namePattern)) {
"Invalid _meta key '$key': name must start and end with alphanumeric characters, and contain only alphanumerics, hyphens, underscores, or dots"
}
}
}

private fun convertToJsonMap(map: Map<String, Any?>): Map<String, JsonElement> = map.mapValues { (key, value) ->
try {
convertToJsonElement(value)
} catch (e: Exception) {
logger.warn { "Failed to convert value for key '$key': ${e.message}. Using string representation." }
JsonPrimitive(value.toString())
}
}

@OptIn(ExperimentalUnsignedTypes::class, ExperimentalSerializationApi::class)
private fun convertToJsonElement(value: Any?): JsonElement = when (value) {
Copy link
Contributor

@kpavlov kpavlov Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to extract JSON-related code to a separate class. It could be a separate PR with unit tests

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be moved to separate class

null -> JsonNull

is JsonElement -> value

is String -> JsonPrimitive(value)

is Number -> JsonPrimitive(value)

is Boolean -> JsonPrimitive(value)

is Char -> JsonPrimitive(value.toString())

is Enum<*> -> JsonPrimitive(value.name)

is Map<*, *> -> buildJsonObject { value.forEach { (k, v) -> put(k.toString(), convertToJsonElement(v)) } }

is Collection<*> -> buildJsonArray { value.forEach { add(convertToJsonElement(it)) } }

is Array<*> -> buildJsonArray { value.forEach { add(convertToJsonElement(it)) } }

// Primitive arrays
is IntArray -> buildJsonArray { value.forEach { add(it) } }

is LongArray -> buildJsonArray { value.forEach { add(it) } }

is FloatArray -> buildJsonArray { value.forEach { add(it) } }

is DoubleArray -> buildJsonArray { value.forEach { add(it) } }

is BooleanArray -> buildJsonArray { value.forEach { add(it) } }

is ShortArray -> buildJsonArray { value.forEach { add(it) } }

is ByteArray -> buildJsonArray { value.forEach { add(it) } }

is CharArray -> buildJsonArray { value.forEach { add(it.toString()) } }

// Unsigned arrays
is UIntArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } }

is ULongArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } }

is UShortArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } }

is UByteArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } }

else -> {
logger.debug { "Converting unknown type ${value::class} to string: $value" }
JsonPrimitive(value.toString())
}
}
}
Loading