diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index a785a916..46a7b60d 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -8,9 +8,9 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V public fun assertRequestHandlerCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V public final fun callTool (Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun callTool (Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun callTool (Ljava/lang/String;Ljava/util/Map;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CallToolRequest;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; - public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public static synthetic fun callTool$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Ljava/lang/String;Ljava/util/Map;Ljava/util/Map;ZLio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun complete (Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun complete$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/CompleteRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public fun connect (Lio/modelcontextprotocol/kotlin/sdk/shared/Transport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 95a5bc5b..56fd1caf 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -50,10 +50,14 @@ import kotlinx.atomicfu.update import kotlinx.collections.immutable.minus import kotlinx.collections.immutable.persistentMapOf import kotlinx.collections.immutable.toPersistentSet +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject import kotlin.coroutines.cancellation.CancellationException private val logger = KotlinLogging.logger {} @@ -210,17 +214,13 @@ public open class Client(private val clientInfo: Implementation, options: Client } } - Method.Defined.ToolsCall, - Method.Defined.ToolsList, - -> { + Method.Defined.ToolsCall, Method.Defined.ToolsList -> { if (serverCapabilities?.tools == null) { throw IllegalStateException("Server does not support tools (required for $method)") } } - Method.Defined.Initialize, - Method.Defined.Ping, - -> { + Method.Defined.Initialize, Method.Defined.Ping -> { // No specific capability required } @@ -405,10 +405,14 @@ public open class Client(private val clientInfo: Implementation, options: Client ): EmptyRequestResult = request(request, options) /** - * Calls a tool on the server by name, passing the specified arguments. + * Calls a tool on the server by name, passing the specified arguments and metadata. * * @param name The name of the tool to call. * @param arguments A map of argument names to values for the tool. + * @param meta A map of metadata key-value pairs. Keys must follow MCP specification format. + * - Optional prefix: dot-separated labels followed by slash (e.g., "api.example.com/") + * - Name: alphanumeric start/end, may contain hyphens, underscores, dots, alphanumerics + * - Reserved prefixes starting with "mcp" or "modelcontextprotocol" are forbidden * @param compatibility Whether to use compatibility mode for older protocol versions. * @param options Optional request options. * @return The result of the tool call, or `null` if none. @@ -417,23 +421,19 @@ public open class Client(private val clientInfo: Implementation, options: Client public suspend fun callTool( name: String, arguments: Map, + meta: Map = emptyMap(), compatibility: Boolean = false, options: RequestOptions? = null, ): CallToolResultBase? { - val jsonArguments = arguments.mapValues { (_, value) -> - when (value) { - is String -> JsonPrimitive(value) - is Number -> JsonPrimitive(value) - is Boolean -> JsonPrimitive(value) - is JsonElement -> value - null -> JsonNull - else -> JsonPrimitive(value.toString()) - } - } + validateMetaKeys(meta.keys) + + val jsonArguments = convertToJsonMap(arguments) + val jsonMeta = convertToJsonMap(meta) val request = CallToolRequest( name = name, arguments = JsonObject(jsonArguments), + _meta = JsonObject(jsonMeta), ) return callTool(request, compatibility, options) } @@ -588,4 +588,116 @@ public open class Client(private val clientInfo: Implementation, options: Client val rootList = roots.value.values.toList() return ListRootsResult(rootList) } + + /** + * Validates meta keys according to MCP specification. + * + * Key format: [prefix/]name + * - Prefix (optional): dot-separated labels + slash + * - Reserved prefixes contain "modelcontextprotocol" or "mcp" as complete labels + * - Name: alphanumeric start/end, may contain hyphens, underscores, dots (empty allowed) + */ + private fun validateMetaKeys(keys: Set) { + val labelPattern = Regex("[a-zA-Z]([a-zA-Z0-9-]*[a-zA-Z0-9])?") + val namePattern = Regex("[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?") + + keys.forEach { key -> + require(key.isNotEmpty()) { "Meta key cannot be empty" } + + val (prefix, name) = key.split('/', limit = 2).let { parts -> + when (parts.size) { + 1 -> null to parts[0] + 2 -> parts[0] to parts[1] + else -> throw IllegalArgumentException("Unexpected split result for key: $key") + } + } + + // Validate prefix if present + prefix?.let { + require(it.isNotEmpty()) { "Invalid _meta key '$key': prefix cannot be empty" } + + val labels = it.split('.') + require(labels.all { label -> label.matches(labelPattern) }) { + "Invalid _meta key '$key': prefix labels must start with a letter, end with letter/digit, and contain only letters, digits, or hyphens" + } + + require( + labels.none { label -> + label.equals("modelcontextprotocol", ignoreCase = true) || + label.equals("mcp", ignoreCase = true) + }, + ) { + "Invalid _meta key '$key': prefix cannot contain reserved labels 'modelcontextprotocol' or 'mcp'" + } + } + + // Validate name (empty allowed) + require(name.isEmpty() || name.matches(namePattern)) { + "Invalid _meta key '$key': name must start and end with alphanumeric characters, and contain only alphanumerics, hyphens, underscores, or dots" + } + } + } + + private fun convertToJsonMap(map: Map): Map = map.mapValues { (key, value) -> + try { + convertToJsonElement(value) + } catch (e: Exception) { + logger.warn { "Failed to convert value for key '$key': ${e.message}. Using string representation." } + JsonPrimitive(value.toString()) + } + } + + @OptIn(ExperimentalUnsignedTypes::class, ExperimentalSerializationApi::class) + private fun convertToJsonElement(value: Any?): JsonElement = when (value) { + null -> JsonNull + + is JsonElement -> value + + is String -> JsonPrimitive(value) + + is Number -> JsonPrimitive(value) + + is Boolean -> JsonPrimitive(value) + + is Char -> JsonPrimitive(value.toString()) + + is Enum<*> -> JsonPrimitive(value.name) + + is Map<*, *> -> buildJsonObject { value.forEach { (k, v) -> put(k.toString(), convertToJsonElement(v)) } } + + is Collection<*> -> buildJsonArray { value.forEach { add(convertToJsonElement(it)) } } + + is Array<*> -> buildJsonArray { value.forEach { add(convertToJsonElement(it)) } } + + // Primitive arrays + is IntArray -> buildJsonArray { value.forEach { add(it) } } + + is LongArray -> buildJsonArray { value.forEach { add(it) } } + + is FloatArray -> buildJsonArray { value.forEach { add(it) } } + + is DoubleArray -> buildJsonArray { value.forEach { add(it) } } + + is BooleanArray -> buildJsonArray { value.forEach { add(it) } } + + is ShortArray -> buildJsonArray { value.forEach { add(it) } } + + is ByteArray -> buildJsonArray { value.forEach { add(it) } } + + is CharArray -> buildJsonArray { value.forEach { add(it.toString()) } } + + // Unsigned arrays + is UIntArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } } + + is ULongArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } } + + is UShortArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } } + + is UByteArray -> buildJsonArray { value.forEach { add(JsonPrimitive(it)) } } + + else -> { + logger.debug { "Converting unknown type ${value::class} to string: $value" } + JsonPrimitive(value.toString()) + } + } } diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt new file mode 100644 index 00000000..e7061073 --- /dev/null +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt @@ -0,0 +1,274 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.boolean +import kotlinx.serialization.json.int +import kotlinx.serialization.json.jsonPrimitive +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +/** + * Comprehensive test suite for MCP Client meta parameter functionality + * + * Tests cover: + * - Meta key validation according to MCP specification + * - JSON type conversion for various data types + * - Error handling for invalid meta keys + * - Integration with callTool method + */ +class ClientMetaParameterTest { + + private lateinit var client: Client + private lateinit var mockTransport: MockTransport + private val clientInfo = Implementation("test-client", "1.0.0") + + @BeforeTest + fun setup() = runTest { + mockTransport = MockTransport() + client = Client(clientInfo = clientInfo) + mockTransport.setupInitializationResponse() + client.connect(mockTransport) + } + + @Test + fun `should accept valid meta keys without throwing exception`() = runTest { + val validMeta = buildMap { + put("simple-key", "value1") + put("api.example.com/version", "1.0") + put("com.company.app/setting", "enabled") + put("retry_count", 3) + put("user.preference", true) + put("valid123", "alphanumeric") + put("multi.dot.name", "multiple-dots") + put("under_score", "underscore") + put("hyphen-dash", "hyphen") + put("org.apache.kafka/consumer-config", "complex-valid-prefix") + } + + val result = runCatching { + client.callTool("test-tool", mapOf("arg" to "value"), validMeta) + } + + assertTrue(result.isSuccess, "Valid meta keys should not cause exceptions") + mockTransport.lastJsonRpcRequest()?.let { request -> + val params = request.params as JsonObject + assertTrue(params.containsKey("_meta"), "Request should contain _meta field") + val metaField = params["_meta"] as JsonObject + + // Verify all meta keys are present + assertEquals(validMeta.size, metaField.size, "All meta keys should be included") + + // Verify specific key-value pairs + assertEquals("value1", metaField["simple-key"]?.jsonPrimitive?.content) + assertEquals("1.0", metaField["api.example.com/version"]?.jsonPrimitive?.content) + assertEquals("enabled", metaField["com.company.app/setting"]?.jsonPrimitive?.content) + assertEquals(3, metaField["retry_count"]?.jsonPrimitive?.int) + assertEquals(true, metaField["user.preference"]?.jsonPrimitive?.boolean) + assertEquals("alphanumeric", metaField["valid123"]?.jsonPrimitive?.content) + assertEquals("multiple-dots", metaField["multi.dot.name"]?.jsonPrimitive?.content) + assertEquals("underscore", metaField["under_score"]?.jsonPrimitive?.content) + assertEquals("hyphen", metaField["hyphen-dash"]?.jsonPrimitive?.content) + assertEquals("complex-valid-prefix", metaField["org.apache.kafka/consumer-config"]?.jsonPrimitive?.content) + } + } + + @Test + fun `should accept edge case valid prefixes and names`() = runTest { + val edgeCaseValidMeta = buildMap { + put("a/", "single-char-prefix-empty-name") // empty name is allowed + put("a1-b2/test", "alphanumeric-hyphen-prefix") + put("long.domain.name.here/config", "long-prefix") + put("x/a", "minimal-valid-key") + put("test123", "alphanumeric-name-only") + } + + val result = runCatching { + client.callTool("test-tool", emptyMap(), edgeCaseValidMeta) + } + + assertTrue(result.isSuccess, "Edge case valid meta keys should be accepted") + } + + @Test + fun `should reject mcp reserved prefix`() = runTest { + val invalidMeta = mapOf("mcp/internal" to "value") + + val exception = assertFailsWith { + client.callTool("test-tool", emptyMap(), invalidMeta) + } + + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + + @Test + fun `should reject modelcontextprotocol reserved prefix`() = runTest { + val invalidMeta = mapOf("modelcontextprotocol/config" to "value") + + val exception = assertFailsWith { + client.callTool("test-tool", emptyMap(), invalidMeta) + } + + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + + @Test + fun `should reject nested reserved prefixes`() = runTest { + val invalidKeys = listOf( + "api.mcp.io/setting", + "com.modelcontextprotocol.test/value", + "example.mcp/data", + "subdomain.mcp.com/config", + "app.modelcontextprotocol.dev/setting", + "test.mcp/value", + "service.modelcontextprotocol/data", + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject nested reserved key: $key", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should reject case-insensitive reserved prefixes`() = runTest { + val invalidKeys = listOf( + "MCP/internal", + "Mcp/config", + "mCp/setting", + "MODELCONTEXTPROTOCOL/data", + "ModelContextProtocol/value", + "modelContextProtocol/test", + ) + + invalidKeys.forEach { key -> + val exception = assertFailsWith( + message = "Should reject case-insensitive reserved key: $key", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + assertContains( + charSequence = exception.message ?: "", + other = "Invalid _meta key", + ) + } + } + + @Test + fun `should reject invalid key formats`() = runTest { + val invalidKeys = listOf( + "", // empty key - not allowed at key level + "/invalid", // starts with slash + "-invalid", // starts with hyphen + ".invalid", // starts with dot + "in valid", // contains space + "api../test", // consecutive dots + "api./test", // label ends with dot + ) + + invalidKeys.forEach { key -> + assertFailsWith( + message = "Should reject invalid key format: '$key'", + ) { + client.callTool("test-tool", emptyMap(), mapOf(key to "value")) + } + } + } + + @Test + fun `should convert various data types to JSON correctly`() = runTest { + val complexMeta = createComplexMetaData() + + val result = runCatching { + client.callTool( + "test-tool", + emptyMap(), + complexMeta, + ) + } + + assertTrue(result.isSuccess, "Complex data type conversion should not throw exceptions") + + mockTransport.lastJsonRpcRequest()?.let { request -> + assertEquals("tools/call", request.method) + val params = request.params as JsonObject + assertTrue(params.containsKey("_meta"), "Request should contain _meta field") + } + } + + @Test + fun `should handle nested map structures correctly`() = runTest { + val nestedMeta = buildNestedConfiguration() + + val result = runCatching { + client.callTool("test-tool", emptyMap(), nestedMeta) + } + + assertTrue(result.isSuccess) + + mockTransport.lastJsonRpcRequest()?.let { request -> + val params = request.params as JsonObject + val metaField = params["_meta"] as JsonObject + assertTrue(metaField.containsKey("config")) + } + } + + @Test + fun `should include empty meta object when meta parameter not provided`() = runTest { + client.callTool("test-tool", mapOf("arg" to "value")) + + mockTransport.lastJsonRpcRequest()?.let { request -> + val params = request.params as JsonObject + val metaField = params["_meta"] as JsonObject + assertTrue(metaField.isEmpty(), "Meta field should be empty when not provided") + } + } + + private fun createComplexMetaData(): Map = buildMap { + put("string", "text") + put("number", 42) + put("boolean", true) + put("null_value", null) + put("list", listOf(1, 2, 3)) + put("map", mapOf("nested" to "value")) + put("enum", "STRING") + put("int_array", intArrayOf(1, 2, 3)) + } + + private fun buildNestedConfiguration(): Map = buildMap { + put( + "config", + buildMap { + put( + "database", + buildMap { + put("host", "localhost") + put("port", 5432) + }, + ) + put("features", listOf("feature1", "feature2")) + }, + ) + } +} + +suspend fun MockTransport.lastJsonRpcRequest(): JSONRPCRequest? = getSentMessages().lastOrNull() as? JSONRPCRequest diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt new file mode 100644 index 00000000..c987619d --- /dev/null +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt @@ -0,0 +1,94 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.InitializeResult +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock + +class MockTransport : Transport { + private val _sentMessages = mutableListOf() + private val _receivedMessages = mutableListOf() + private val mutex = Mutex() + + suspend fun getSentMessages() = mutex.withLock { _sentMessages.toList() } + suspend fun getReceivedMessages() = mutex.withLock { _receivedMessages.toList() } + + private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null + private var onCloseBlock: (() -> Unit)? = null + private var onErrorBlock: ((Throwable) -> Unit)? = null + + override suspend fun start() = Unit + + override suspend fun send(message: JSONRPCMessage) { + mutex.withLock { + _sentMessages += message + } + + // Auto-respond to initialization and tool calls + when (message) { + is JSONRPCRequest -> { + when (message.method) { + "initialize" -> { + val initResponse = JSONRPCResponse( + id = message.id, + result = InitializeResult( + protocolVersion = "2024-11-05", + capabilities = ServerCapabilities( + tools = ServerCapabilities.Tools(listChanged = null), + ), + serverInfo = Implementation("mock-server", "1.0.0"), + ), + ) + onMessageBlock?.invoke(initResponse) + } + + "tools/call" -> { + val toolResponse = JSONRPCResponse( + id = message.id, + result = CallToolResult( + content = listOf(), + isError = false, + ), + ) + onMessageBlock?.invoke(toolResponse) + } + } + } + + else -> { + // Handle other message types if needed + } + } + } + + override suspend fun close() { + onCloseBlock?.invoke() + } + + override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { + onMessageBlock = { message -> + mutex.withLock { + _receivedMessages += message + } + block(message) + } + } + + override fun onClose(block: () -> Unit) { + onCloseBlock = block + } + + override fun onError(block: (Throwable) -> Unit) { + onErrorBlock = block + } + + fun setupInitializationResponse() { + // This method helps set up the mock for proper initialization + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt index 3b0de299..a9a8f278 100644 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt @@ -484,7 +484,7 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() { "result" : 11.0, "formattedResult" : "11,000", "precision" : 3, - "tags" : [ ] + "tags" : ["test", "calculator", "integration"] } """.trimIndent()