diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java new file mode 100644 index 00000000..28bb1fe8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/UDSClientTransportProvider.java @@ -0,0 +1,185 @@ +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.time.Duration; +import java.util.concurrent.Executors; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSClientNonBlockingSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSClientTransportProvider implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(UDSClientTransportProvider.class); + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private ObjectMapper objectMapper; + + private UDSClientNonBlockingSocketChannel clientChannel; + + private UnixDomainSocketAddress targetAddress; + + private Scheduler outboundScheduler; + + private volatile boolean isClosing = false; + + public UDSClientTransportProvider(UnixDomainSocketAddress targetAddress) throws IOException { + this(new ObjectMapper(), targetAddress); + } + + public UDSClientTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress targetAddress) + throws IOException { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + this.objectMapper = objectMapper; + + // Start threads + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), "outbound"); + this.clientChannel = new UDSClientNonBlockingSocketChannel(); + this.targetAddress = targetAddress; + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.fromRunnable(() -> { + handleIncomingMessages(handler); + try { + this.clientChannel.connectBlocking(targetAddress, (client) -> { + logger.info("CONNECTED to targetAddress=" + targetAddress); + }, (data) -> { + JSONRPCMessage json = McpSchema.deserializeJsonRpcMessage(this.objectMapper, data); + if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + if (!isClosing) { + logger.error("Failed to enqueue inbound message: {}", json); + } + } + }); + } + catch (IOException e) { + this.clientChannel.close(); + throw new RuntimeException( + "Connect to address=" + targetAddress + " failed message: " + e.getMessage()); + } + startOutboundProcessing(); + }).subscribeOn(Schedulers.boundedElastic()); + } + + private void handleIncomingMessages(Function, Mono> inboundMessageHandler) { + this.inboundSink.asFlux() + .flatMap(message -> Mono.just(message) + .transform(inboundMessageHandler) + .contextWrite(ctx -> ctx.put("observation", "myObservation"))) + .subscribe(); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + if (this.outboundSink.tryEmitNext(message).isSuccess()) { + // TODO: essentially we could reschedule ourselves in some time and make + // another attempt with the already read data but pause reading until + // success + // In this approach we delegate the retry and the backpressure onto the + // caller. This might be enough for most cases. + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + } + + private void startOutboundProcessing() { + this.handleOutbound(messages -> messages + // this bit is important since writes come from user threads, and we + // want to ensure that the actual writing happens on a dedicated thread + .publishOn(outboundScheduler) + .handle((message, s) -> { + if (message != null && !isClosing) { + try { + this.clientChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + s.next(message); + } + catch (IOException e) { + s.error(new RuntimeException(e)); + } + } + })); + } + + protected void handleOutbound(Function, Flux> outboundConsumer) { + outboundConsumer.apply(outboundSink.asFlux()).doOnComplete(() -> { + isClosing = true; + outboundSink.tryEmitComplete(); + }).doOnError(e -> { + if (!isClosing) { + logger.error("Error in outbound processing", e); + isClosing = true; + outboundSink.tryEmitComplete(); + } + }).subscribe(); + } + + /** + * Gracefully closes the transport by destroying the process and disposing of the + * schedulers. This method sends a TERM signal to the process and waits for it to exit + * before cleaning up resources. + * @return A Mono that completes when the transport is closed + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing = true; + logger.debug("Initiating graceful shutdown"); + }).then(Mono.defer(() -> { + // First complete all sinks to stop accepting new messages + inboundSink.tryEmitComplete(); + outboundSink.tryEmitComplete(); + // Give a short time for any pending messages to be processed + return Mono.delay(Duration.ofMillis(100)).then(); + })).then(Mono.defer(() -> { + // Close our clientChannel + if (this.clientChannel != null) { + this.clientChannel.close(); + this.clientChannel = null; + } + return Mono.empty(); + })).then(Mono.fromRunnable(() -> { + try { + // The Threads are blocked on readLine so disposeGracefully would not + // interrupt them, therefore we issue an async hard dispose. + outboundScheduler.dispose(); + + logger.debug("Graceful shutdown completed"); + } + catch (Exception e) { + logger.error("Error during graceful shutdown", e); + } + })).then().subscribeOn(Schedulers.boundedElastic()); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java new file mode 100644 index 00000000..977f5b90 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/UDSServerTransportProvider.java @@ -0,0 +1,242 @@ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.UDSServerNonBlockingSocketChannel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class UDSServerTransportProvider implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(UDSServerTransportProvider.class); + + private final ObjectMapper objectMapper; + + private McpServerSession session; + + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + private final Sinks.One inboundReady = Sinks.one(); + + private UDSServerNonBlockingSocketChannel serverSocketChannel; + + private UnixDomainSocketAddress address; + + private UDSMcpSessionTransport transport; + + public UDSServerTransportProvider(UnixDomainSocketAddress unixSocketAddress) { + this(new ObjectMapper(), unixSocketAddress); + } + + public UDSServerTransportProvider(ObjectMapper objectMapper, UnixDomainSocketAddress unixSocketAddress) { + Assert.notNull(objectMapper, "The ObjectMapper can not be null"); + this.objectMapper = objectMapper; + this.address = unixSocketAddress; + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.transport = new UDSMcpSessionTransport(); + this.session = sessionFactory.create(transport); + this.transport.handleIncomingMessages(); + if (this.transport.isStarted.compareAndSet(false, true)) { + inboundReady.tryEmitValue(null); + } + // Also start listening for accept + try { + this.serverSocketChannel = new UDSServerNonBlockingSocketChannel(); + this.serverSocketChannel.start(this.address, (clientChannel) -> { + if (logger.isDebugEnabled()) { + logger.debug("Accepted connect from clientChannel=" + clientChannel); + } + // Start outbound processing now that the clientChannel has been accepted + this.transport.startOutboundProcessing(); + }, (dataLine) -> { + String message = (String) dataLine; + if (logger.isDebugEnabled()) { + logger.debug("Received message line=" + message); + } + try { + this.transport + .handleMessage(McpSchema.deserializeJsonRpcMessage(this.objectMapper, message.trim())); + } + catch (IOException e) { + this.serverSocketChannel.close(); + } + }); + } + catch (IOException e) { + // If this happens then we are doomed + this.serverSocketChannel.close(); + throw new RuntimeException("accepterNonBlockSocketChannel could not be started"); + } + } + + @Override + public Mono notifyClients(String method, Object params) { + if (this.session == null) { + return Mono.error(new McpError("No session to close")); + } + return this.session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send notification: {}", e.getMessage())); + } + + @Override + public Mono closeGracefully() { + if (this.session == null) { + return Mono.empty(); + } + return this.session.closeGracefully(); + } + + /** + * Implementation of McpServerTransport for the stdio session. + */ + private class UDSMcpSessionTransport implements McpServerTransport { + + private final Sinks.Many inboundSink; + + private final Sinks.Many outboundSink; + + private final AtomicBoolean isStarted = new AtomicBoolean(false); + + /** Scheduler for handling outbound messages */ + private Scheduler outboundScheduler; + + private final Sinks.One outboundReady = Sinks.one(); + + public UDSMcpSessionTransport() { + + this.inboundSink = Sinks.many().unicast().onBackpressureBuffer(); + this.outboundSink = Sinks.many().unicast().onBackpressureBuffer(); + + this.outboundScheduler = Schedulers.fromExecutorService(Executors.newSingleThreadExecutor(), + "uds-outbound"); + } + + public void handleMessage(McpSchema.JSONRPCMessage json) throws IOException { + try { + if (!this.inboundSink.tryEmitNext(json).isSuccess()) { + throw new Exception("Failed to enqueue message"); + } + } + catch (Exception e) { + logIfNotClosing("Error processing inbound message", e); + throw new IOException("Error in processing inbound message", e); + } + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + + return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { + if (outboundSink.tryEmitNext(message).isSuccess()) { + return Mono.empty(); + } + else { + return Mono.error(new RuntimeException("Failed to enqueue message")); + } + })); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + isClosing.set(true); + logger.debug("Session transport closing gracefully"); + inboundSink.tryEmitComplete(); + }); + } + + @Override + public void close() { + isClosing.set(true); + logger.debug("Session transport closed"); + } + + private void handleIncomingMessages() { + this.inboundSink.asFlux().flatMap(message -> session.handle(message)).doOnTerminate(() -> { + // The outbound processing will dispose its scheduler upon completion + this.outboundSink.tryEmitComplete(); + // this.inboundScheduler.dispose(); + }).subscribe(); + } + + /** + * Starts the outbound processing thread that writes JSON-RPC messages to stdout. + * Messages are serialized to JSON and written with a newline delimiter. + */ + private void startOutboundProcessing() { + Function, Flux> outboundConsumer = messages -> messages // @formatter:off + .doOnSubscribe(subscription -> outboundReady.tryEmitValue(null)) + .publishOn(outboundScheduler) + .handle((message, sink) -> { + if (message != null && !isClosing.get()) { + try { + serverSocketChannel.writeMessageBlocking(objectMapper.writeValueAsString(message)); + sink.next(message); + } + catch (IOException e) { + if (!isClosing.get()) { + logger.error("Error writing message", e); + sink.error(new RuntimeException(e)); + } + else { + logger.debug("Stream closed during shutdown", e); + } + } + } + else if (isClosing.get()) { + sink.complete(); + } + }) + .doOnComplete(() -> { + isClosing.set(true); + outboundScheduler.dispose(); + }) + .doOnError(e -> { + if (!isClosing.get()) { + logger.error("Error in outbound processing", e); + isClosing.set(true); + outboundScheduler.dispose(); + } + }) + .map(msg -> (JSONRPCMessage) msg); + + outboundConsumer.apply(outboundSink.asFlux()).subscribe(); + } // @formatter:on + + private void logIfNotClosing(String message, Exception e) { + if (!isClosing.get()) { + logger.error(message, e); + } + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java new file mode 100644 index 00000000..33c7f5f7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ClientNonBlockingSocketChannel.java @@ -0,0 +1,66 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClientNonBlockingSocketChannel extends NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ClientNonBlockingSocketChannel.class); + + private SocketChannel client; + + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(StandardProtocolFamily protocol, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + if (this.client != null) { + throw new IOException("Already connected"); + } + this.client = connectBlocking(SocketChannel.open(protocol), address, connectHandler, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + @Override + public void close() { + hardCloseClient(this.client, (client) -> { + this.client = null; + }); + } + + public void writeMessageBlocking(String message) throws IOException { + if (this.client == null) { + throw new IOException("Cannot write until client connected"); + } + writeBlocking(client, message); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java new file mode 100644 index 00000000..b1186e3c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ClientNonBlockingSocketChannel.java @@ -0,0 +1,35 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet4ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public Inet4ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet4ClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(Inet4Address address, int port, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.INET, new InetSocketAddress(address, port), connectHandler, + readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java new file mode 100644 index 00000000..a4b9c61f --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet4ServerNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet4ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public Inet4ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet4ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(Inet4Address address, int port, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.INET, new InetSocketAddress(address, port), acceptHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java new file mode 100644 index 00000000..9af48485 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ClientNonBlockingSocketChannel.java @@ -0,0 +1,35 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet6ClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public Inet6ClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet6ClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(Inet6Address address, int port, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), connectHandler, + readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java new file mode 100644 index 00000000..8a1a95e2 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Inet6ServerNonBlockingSocketChannel.java @@ -0,0 +1,34 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetSocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class Inet6ServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public Inet6ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public Inet6ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(Inet6Address address, int port, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.INET6, new InetSocketAddress(address, port), acceptHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java new file mode 100644 index 00000000..de3fe8ba --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/NonBlockingSocketChannel.java @@ -0,0 +1,397 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(NonBlockingSocketChannel.class); + + public static final int DEFAULT_INBUFFER_SIZE = 1024; + + protected static String MESSAGE_DELIMITER = "\n"; + + protected static int BLOCKING_WRITE_TIMEOUT = 5000; + + protected static int BLOCKING_CONNECT_TIMEOUT = 10000; + + protected final Selector selector; + + protected final ByteBuffer inBuffer; + + protected final ExecutorService executor; + + @FunctionalInterface + public interface IOConsumer { + + void apply(T t) throws IOException; + + } + + protected class AttachedIO { + + public ByteBuffer writing; + + public StringBuffer reading; + + } + + public NonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + Assert.notNull(selector, "Selector must not be null"); + this.selector = selector; + this.inBuffer = ByteBuffer.allocate(incomingBufferSize); + this.executor = (executor == null) ? Executors.newSingleThreadExecutor() : executor; + } + + public NonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + this(selector, incomingBufferSize, null); + } + + public NonBlockingSocketChannel(Selector selector) { + this(selector, DEFAULT_INBUFFER_SIZE); + } + + public NonBlockingSocketChannel() throws IOException { + this(Selector.open()); + } + + protected Runnable getRunnableForProcessing(IOConsumer acceptHandler, + IOConsumer connectHandler, IOConsumer readHandler) { + return () -> { + SelectionKey key = null; + try { + while (true) { + this.selector.select(); + Set selectedKeys = selector.selectedKeys(); + Iterator iter = selectedKeys.iterator(); + while (iter.hasNext()) { + key = iter.next(); + if (key.isConnectable()) { + handleConnectable(key, connectHandler); + } + else if (key.isAcceptable()) { + handleAcceptable(key, acceptHandler); + } + else if (key.isReadable()) { + handleReadable(key, readHandler); + } + else if (key.isWritable()) { + handleWritable(key); + } + iter.remove(); + } + } + } + catch (Exception e) { + handleException(key, e); + } + }; + } + + public abstract void close(); + + protected abstract void handleException(SelectionKey key, Exception e); + + protected void start(IOConsumer acceptHandler, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + this.executor.execute(getRunnableForProcessing(acceptHandler, connectHandler, readHandler)); + } + + // For client subclasses + protected void handleConnectable(SelectionKey key, IOConsumer connectHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("handleConnectable client=" + client.getRemoteAddress()); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + if (client.isConnectionPending()) { + client.finishConnect(); + if (logger.isDebugEnabled()) { + logger.debug("handleConnectable FINISHED"); + } + } + if (connectHandler != null) { + connectHandler.apply(client); + } + } + } + + protected void handleAcceptable(SelectionKey key, IOConsumer acceptHandler) throws IOException { + ServerSocketChannel serverSocket = (ServerSocketChannel) key.channel(); + SocketChannel client = serverSocket.accept(); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("handleAcceptable client=" + client); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(this.selector, SelectionKey.OP_READ, new AttachedIO()); + configureAcceptSocketChannel(client); + if (client.isConnectionPending()) { + client.finishConnect(); + if (logger.isDebugEnabled()) { + logger.debug("handleAcceptable FINISHED"); + } + } + if (acceptHandler != null) { + acceptHandler.apply(client); + } + } + } + + protected void configureAcceptSocketChannel(SocketChannel client) throws IOException { + // Subclasses may override + } + + protected AttachedIO getAttachedIO(SelectionKey key) throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + if (io == null) { + throw new IOException("No AttachedIO object found on key"); + } + return io; + } + + protected void handleReadable(SelectionKey key, IOConsumer readHandler) throws IOException { + SocketChannel client = (SocketChannel) key.channel(); + Object lock = client.blockingLock(); + AttachedIO io = getAttachedIO(key); + if (logger.isDebugEnabled()) { + logger.debug("handleReadable client=" + client); + } + synchronized (lock) { + // non-blocking read here + int r = client.read(this.inBuffer); + // Check if we should expect any more reads + if (r == -1) { + throw new IOException("Channel read reached end of stream"); + } + this.inBuffer.flip(); + String partial = new String(this.inBuffer.array(), 0, r, StandardCharsets.UTF_8); + // If there is are previous partial, then get the io.reading string Buffer + StringBuffer sb = (io.reading != null) ? (StringBuffer) io.reading : new StringBuffer(); + // And append the just read partial to the string buffer + sb.append(partial); + if (partial.endsWith(MESSAGE_DELIMITER)) { + // Get the entire message from the string buffer + String message = sb.toString(); + // Set the io.reading value to null as we are done with this message + io.reading = null; + if (logger.isDebugEnabled()) { + logger.debug("handleReadable COMPLETE msg=" + message); + } + if (readHandler != null) { + readHandler.apply(message); + } + } + else { + io.reading = sb; + if (logger.isDebugEnabled()) { + logger.debug("handleReadable PARTIAL msg=" + partial); + } + } + } + // Clear inbuffer for next read + this.inBuffer.clear(); + } + + protected void handleWritable(SelectionKey key) throws IOException { + ByteBuffer buf = getAttachedIO(key).writing; + SocketChannel client = (SocketChannel) key.channel(); + if (buf != null) { + doWrite(key, client, buf, (lock) -> { + synchronized (lock) { + if (logger.isDebugEnabled()) { + logger.debug("handleWritable NOTIFY client=" + client); + } + lock.notifyAll(); + } + }); + } + } + + protected void doWrite(SocketChannel client, String message, IOConsumer writeHandler) throws IOException { + Assert.notNull(client, "Client must not be null"); + Assert.notNull(message, "Message must not be null"); + if (logger.isDebugEnabled()) { + logger.debug("doWrite msg=" + message); + } + doWrite(client.keyFor(this.selector), client, ByteBuffer.wrap(message.getBytes(StandardCharsets.UTF_8)), + writeHandler); + } + + protected void doWrite(SelectionKey key, SocketChannel client, ByteBuffer buf, IOConsumer writeHandler) + throws IOException { + AttachedIO io = (AttachedIO) key.attachment(); + Object lock = client.blockingLock(); + synchronized (lock) { + int written = client.write(buf); + if (buf.hasRemaining()) { + if (logger.isDebugEnabled()) { + logger.debug("doWrite PARTIAL written=" + written + " remaining=" + buf.remaining()); + } + io.writing = buf.slice(); + key.interestOpsOr(SelectionKey.OP_WRITE); + } + else { + if (logger.isDebugEnabled()) { + logger.debug("doWrite COMPLETED msg=" + new String(buf.array(), 0, written)); + } + io.writing = null; + key.interestOps(SelectionKey.OP_READ); + if (writeHandler != null) { + writeHandler.apply(lock); + } + } + } + } + + protected void executorShutdown() { + if (!this.executor.isShutdown()) { + if (logger.isDebugEnabled()) { + logger.debug("executorShutdown"); + } + try { + this.executor.awaitTermination(2000, TimeUnit.MILLISECONDS); + this.executor.shutdown(); + } + catch (InterruptedException e) { + if (logger.isDebugEnabled()) { + logger.debug("Exception in executor awaitTermination", e); + } + } + } + } + + protected void hardCloseClient(SocketChannel client, IOConsumer closeHandler) { + if (client != null) { + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("hardCloseClient client=" + client); + } + synchronized (lock) { + try { + if (closeHandler != null) { + closeHandler.apply(client); + } + client.close(); + client = null; + } + catch (IOException e) { + if (logger.isDebugEnabled()) { + logger.debug("hardClose client socketchannel.close exception", e); + } + } + } + executorShutdown(); + } + } + + protected void writeBlocking(SocketChannel client, String message) throws IOException { + Objects.requireNonNull(client, "Client must not be null"); + Objects.requireNonNull(message, "Message must not be null"); + // Escape any embedded newlines in the JSON message, and add newline + String outputMessage = message.replace("\r\n", "\\n") + .replace("\n", "\\n") + .replace("\r", "\\n") + .concat(MESSAGE_DELIMITER); + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("writeBlocking msg=" + outputMessage); + } + synchronized (lock) { + // do the non blocking write in thread while holding lock. + doWrite(client, outputMessage, null); + ByteBuffer bufRemaining = null; + long waitTime = System.currentTimeMillis() + BLOCKING_WRITE_TIMEOUT; + while (waitTime - System.currentTimeMillis() > 0) { + // Before releasing lock, check for writing buffer remaining + bufRemaining = getAttachedIO(client.keyFor(this.selector)).writing; + if (bufRemaining == null || bufRemaining.remaining() == 0) { + // It's done + break; + } + // If write is *not* completed, then wait timeout /10 + try { + if (logger.isDebugEnabled()) { + logger + .debug("writeBlocking WAITING=" + String.valueOf(waitTime / 10) + " msg=" + outputMessage); + } + lock.wait(waitTime / 10); + } + catch (InterruptedException e) { + throw new InterruptedIOException("write message wait interrupted"); + } + } + if (bufRemaining != null && bufRemaining.remaining() > 0) { + throw new IOException("Write not completed. Non empty buffer remaining after timeout"); + } + } + if (logger.isDebugEnabled()) { + logger.debug("writeBlocking COMPLETED msg=" + outputMessage); + } + } + + protected void configureConnectSocketChannel(SocketChannel client, SocketAddress connectAddress) + throws IOException { + // Subclasses may override + } + + protected SocketChannel connectBlocking(SocketChannel client, SocketAddress address, + IOConsumer connectHandler, IOConsumer readHandler) throws IOException { + Object lock = client.blockingLock(); + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking CONNECTING targetAddress=" + address); + } + synchronized (lock) { + client.configureBlocking(false); + client.register(selector, SelectionKey.OP_CONNECT); + configureConnectSocketChannel(client, address); + // Start the read thread before connect + // No/null accept handler for clients + start(null, (c) -> { + synchronized (lock) { + if (connectHandler != null) { + connectHandler.apply(c); + } + lock.notifyAll(); + } + }, readHandler); + + client.connect(address); + + try { + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking WAITING targetAddress=" + address); + } + lock.wait(BLOCKING_CONNECT_TIMEOUT); + } + catch (InterruptedException e) { + throw new IOException("Connect to address=" + address + " timed out"); + } + if (logger.isDebugEnabled()) { + logger.debug("connectBlocking CONNECTED client=" + client.getLocalAddress() + " connecting=" + address); + } + return client; + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java new file mode 100644 index 00000000..4c64e3d1 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/ServerNonBlockingSocketChannel.java @@ -0,0 +1,89 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.SocketAddress; +import java.net.StandardProtocolFamily; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ServerNonBlockingSocketChannel extends NonBlockingSocketChannel { + + private static final Logger logger = LoggerFactory.getLogger(ServerNonBlockingSocketChannel.class); + + protected SocketChannel acceptedClient; + + public ServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public ServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public ServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + protected void configureServerSocketChannel(ServerSocketChannel serverSocketChannel, SocketAddress acceptAddress) { + // Subclasses may override + } + + public void start(StandardProtocolFamily protocol, SocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + ServerSocketChannel serverChannel = ServerSocketChannel.open(protocol); + serverChannel.configureBlocking(false); + serverChannel.register(this.selector, SelectionKey.OP_ACCEPT); + configureServerSocketChannel(serverChannel, address); + serverChannel.bind(address); + // Start thread/processing of incoming accept, read + super.start((client) -> { + if (logger.isDebugEnabled()) { + logger.debug("Setting client=" + client); + } + this.acceptedClient = client; + if (acceptHandler != null) { + acceptHandler.apply(this.acceptedClient); + } + // No/null connect handler for Acceptors...only accepthandler + }, null, readHandler); + } + + @Override + protected void handleException(SelectionKey key, Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("handleException", e); + } + close(); + } + + public void writeMessageBlocking(String message) throws IOException { + if (this.acceptedClient == null) { + throw new IOException("Cannot write until client connected"); + } + writeBlocking(acceptedClient, message); + } + + @Override + public void close() { + SocketChannel client = this.acceptedClient; + if (client != null) { + hardCloseClient(client, (c) -> { + if (logger.isDebugEnabled()) { + logger.debug("Unsetting client=" + c); + } + this.acceptedClient = null; + }); + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java new file mode 100644 index 00000000..2e279c2b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSClientNonBlockingSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSClientNonBlockingSocketChannel extends ClientNonBlockingSocketChannel { + + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSClientNonBlockingSocketChannel() throws IOException { + super(); + } + + public UDSClientNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSClientNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void connectBlocking(UnixDomainSocketAddress address, IOConsumer connectHandler, + IOConsumer readHandler) throws IOException { + super.connectBlocking(StandardProtocolFamily.UNIX, address, connectHandler, readHandler); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java new file mode 100644 index 00000000..25931571 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/UDSServerNonBlockingSocketChannel.java @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.util; + +import java.io.IOException; +import java.net.StandardProtocolFamily; +import java.net.UnixDomainSocketAddress; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; + +public class UDSServerNonBlockingSocketChannel extends ServerNonBlockingSocketChannel { + + public UDSServerNonBlockingSocketChannel() throws IOException { + super(); + } + + public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize, ExecutorService executor) { + super(selector, incomingBufferSize, executor); + } + + public UDSServerNonBlockingSocketChannel(Selector selector, int incomingBufferSize) { + super(selector, incomingBufferSize); + } + + public UDSServerNonBlockingSocketChannel(Selector selector) { + super(selector); + } + + public void start(UnixDomainSocketAddress address, IOConsumer acceptHandler, + IOConsumer readHandler) throws IOException { + super.start(StandardProtocolFamily.UNIX, address, acceptHandler, readHandler); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java new file mode 100644 index 00000000..99121e01 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpAsyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncClientTests extends AbstractMcpAsyncClientTests { + + UnixDomainSocketAddress address; + EverythingServer server; + + @Override + protected void onStart() { + this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + try { + // Delete this file if exists from previous run + Files.deleteIfExists(this.address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + @Override + protected void onClose() { + server.closeGracefully(); + server = null; + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + address = null; + } + + @Override + protected McpClientTransport createMcpTransport() { + try { + return new UDSClientTransportProvider(address); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java new file mode 100644 index 00000000..c52d98a9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/UDSMcpSyncClientTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; +import java.time.Duration; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.UDSClientTransportProvider; +import io.modelcontextprotocol.server.EverythingServer; +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link UDSClientTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncClientTests extends AbstractMcpSyncClientTests { + + UnixDomainSocketAddress address; + EverythingServer server; + + @Override + protected void onStart() { + this.address = UnixDomainSocketAddress.of(getClass().getName() + ".socket"); + try { + // Delete this file if exists from previous run + Files.deleteIfExists(this.address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + this.server = new EverythingServer(new UDSServerTransportProvider(address)); + } + + @Override + protected void onClose() { + server.closeGracefully(); + server = null; + try { + Files.deleteIfExists(address.getPath()); + } catch (IOException e) { + throw new RuntimeException(e); + } + address = null; + } + + @Override + protected McpClientTransport createMcpTransport() { + try { + return new UDSClientTransportProvider(address); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java new file mode 100644 index 00000000..a158ab2f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/EverythingServer.java @@ -0,0 +1,131 @@ +package io.modelcontextprotocol.server; + +import java.util.List; + +import io.modelcontextprotocol.server.McpServerFeatures.SyncPromptSpecification; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema.Annotations; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.GetPromptResult; +import io.modelcontextprotocol.spec.McpSchema.ProgressNotification; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptMessage; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.Resource; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.SamplingMessage; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest.ContextInclusionStrategy; + +public class EverythingServer { + + private static final String TEST_RESOURCE_URI = "test://resources/"; + + private static final String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + private McpSyncServer server; + + public EverythingServer(McpServerTransportProvider transport) { + McpServerFeatures.SyncResourceSpecification[] specs = new McpServerFeatures.SyncResourceSpecification[10]; + for (int i = 0; i < 10; i++) { + String istr = String.valueOf(i); + String uri = TEST_RESOURCE_URI + istr; + specs[i] = new McpServerFeatures.SyncResourceSpecification( + Resource.builder().uri(uri).name("Test Resource").mimeType("text/plain") + .description("Test resource description").build(), + (exchange, + req) -> new ReadResourceResult(List.of(new TextResourceContents(uri, "text/plain", istr)))); + } + + this.server = McpServer.sync(transport).serverInfo(getClass().getName() + "-server", "1.0.0") + .capabilities( + ServerCapabilities.builder().logging().tools(true).prompts(true).resources(true, true).build()) + .toolCall(Tool.builder().name("echo").description("echo tool description").inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + return CallToolResult.builder().addTextContent((String) request.arguments().get("message")) + .build(); + }) + .toolCall( + Tool.builder().name("add").description("add two integers").inputSchema(emptyJsonSchema).build(), + (exchange, request) -> { + Integer a = (Integer) request.arguments().get("a"); + Integer b = (Integer) request.arguments().get("b"); + + return CallToolResult.builder().addTextContent(String.valueOf(a + b)).build(); + }) + .toolCall(Tool.builder().name("sampleLLM").description("sampleLLM tool").inputSchema(emptyJsonSchema) + .build(), (exchange, request) -> { + String prompt = (String) request.arguments().get("prompt"); + Integer maxTokens = (Integer) request.arguments().get("maxTokens"); + SamplingMessage sm = new SamplingMessage(McpSchema.Role.USER, + new TextContent("Resource sampleLLM context: " + prompt)); + CreateMessageRequest cmRequest = CreateMessageRequest.builder().messages(List.of(sm)) + .systemPrompt("You are a helpful test server.").maxTokens(maxTokens) + .temperature(0.7).includeContext(ContextInclusionStrategy.THIS_SERVER).build(); + CreateMessageResult result = exchange.createMessage(cmRequest); + + return CallToolResult.builder() + .addTextContent("LLM sampling result: " + ((TextContent) result.content()).text()) + .build(); + }) + .toolCall(Tool.builder().name("longRunningOperation") + .description("Demonstrates a long running operation with progress updates") + .inputSchema(emptyJsonSchema).build(), (exchange, request) -> { + String progressToken = (String) request.progressToken(); + int steps = (Integer) request.arguments().get("steps"); + for (int i = 0; i < steps; i++) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + if (progressToken != null) { + exchange.progressNotification( + new ProgressNotification(progressToken, (double) i + 1, (double) steps, + "progress message " + String.valueOf(i + 1))); + } + } + return CallToolResult.builder().content(List.of(new TextContent("done"))).build(); + }) + .toolCall(Tool.builder().name("annotatedMessage").description("annotated message").build(), + (exchange, request) -> { + String messageType = (String) request.arguments().get("messageType"); + Annotations annotations = null; + if (messageType.equals("success")) { + annotations = new Annotations(List.of(McpSchema.Role.USER), 0.7); + } else if (messageType.equals("error")) { + annotations = new Annotations(List.of(McpSchema.Role.USER, McpSchema.Role.ASSISTANT), + 1.0); + } else if (messageType.equals("debug")) { + annotations = new Annotations(List.of(McpSchema.Role.ASSISTANT), 0.3); + } + return CallToolResult.builder().addContent(new TextContent(annotations, "some response")) + .build(); + }) + .prompts(List.of(new SyncPromptSpecification( + new Prompt("simple_prompt", "Simple prompt description", null), (exchange, request) -> { + return new GetPromptResult("description", + List.of(new PromptMessage(Role.USER, new TextContent("hello")))); + }))) + .resources(specs).build(); + } + + public void closeGracefully() { + if (this.server != null) { + this.server.closeGracefully(); + this.server = null; + } + } +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java new file mode 100644 index 00000000..cad1eae5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpAsyncServerTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using {@link UDSServerTransport}. + * + * @author Christian Tzolov + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private UnixDomainSocketAddress address; + + @Override + protected void setUp() { + super.onStart(); + address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); + } + + @Override + protected void tearDown() { + super.onClose(); + if (address != null) { + try { + Files.deleteIfExists(address.getPath()); + } + catch (IOException e) { + } + } + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(address); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java new file mode 100644 index 00000000..57ec7b76 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/UDSMcpSyncServerTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.io.IOException; +import java.net.UnixDomainSocketAddress; +import java.nio.file.Files; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.server.transport.UDSServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using {@link UDSServerTransportProvider}. + * + * @author Christian Tzolov + * @author Scott Lewis + */ +@Timeout(15) // Giving extra time beyond the client timeout +class UDSMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private UnixDomainSocketAddress address; + + @Override + protected void setUp() { + super.onStart(); + address = UnixDomainSocketAddress.of(getClass().getName() + ".unix.socket"); + } + + @Override + protected void tearDown() { + super.onClose(); + if (address != null) { + try { + Files.deleteIfExists(address.getPath()); + } + catch (IOException e) { + } + } + } + + protected McpServerTransportProvider createMcpTransportProvider() { + return new UDSServerTransportProvider(address); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +}