Skip to content

Commit

Permalink
Merge pull request #1171 from jmartisk/mcp
Browse files Browse the repository at this point in the history
LangChain4j 1.0.0-alpha1 & Model Context Protocol client implementation and sample
  • Loading branch information
geoand authored Dec 23, 2024
2 parents b7f861a + f7d76a0 commit 6274d69
Show file tree
Hide file tree
Showing 55 changed files with 2,409 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
Expand Down Expand Up @@ -61,8 +62,8 @@ void testBlockingToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.hello("abc", "hi - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -92,7 +93,7 @@ void testBlockingToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand All @@ -113,7 +114,7 @@ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, In
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -122,8 +123,8 @@ void testNonBlockingToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -153,7 +154,7 @@ void testNonBlockingToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> result.get() != null);
assertThat(result.get()).contains("Tools", "supported");
assertThat(result.get()).contains("tools", "supported");
}

@Test
Expand Down Expand Up @@ -182,7 +183,7 @@ void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() {
});

Awaitility.await().until(() -> result.get() != null);
assertThat(result.get()).contains("Tools", "supported");
assertThat(result.get()).contains("tools", "supported");
}

@Test
Expand All @@ -204,7 +205,7 @@ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException,
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -213,8 +214,8 @@ void testUniToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloUni("abc", "hiUni - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -244,7 +245,7 @@ void testUniToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand All @@ -267,7 +268,7 @@ void testUniToolInvocationFromVirtualThread() throws ExecutionException, Interru
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -277,8 +278,8 @@ void testToolInvocationOnVirtualThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand All @@ -299,7 +300,7 @@ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionExcept
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand Down Expand Up @@ -328,7 +329,7 @@ void testToolInvocationOnVirtualThreadFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
Expand All @@ -31,16 +35,17 @@
* The main difference with the upstream implementation is the thread switch when receiving the `completion` event
* when there is tool execution requests.
*/
public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler<AiMessage> {
public class QuarkusAiServiceStreamingResponseHandler implements StreamingChatResponseHandler {

private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class);

private final AiServiceContext context;
private final Object memoryId;

private final Consumer<String> tokenHandler;
private final Consumer<String> partialResponseHandler;
private final Consumer<Response<AiMessage>> completionHandler;
private final Consumer<ToolExecution> toolExecuteHandler;
private final Consumer<ChatResponse> completeResponseHandler;
private final Consumer<Throwable> errorHandler;

private final List<ChatMessage> temporaryMemory;
Expand All @@ -55,8 +60,9 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon

QuarkusAiServiceStreamingResponseHandler(AiServiceContext context,
Object memoryId,
Consumer<String> tokenHandler,
Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler,
Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler,
List<ChatMessage> temporaryMemory,
Expand All @@ -69,7 +75,8 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");

this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.toolExecuteHandler = toolExecuteHandler;
this.errorHandler = errorHandler;
Expand All @@ -92,16 +99,19 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
}
}

public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, Consumer<String> tokenHandler,
Consumer<ToolExecution> toolExecuteHandler, Consumer<Response<AiMessage>> completionHandler,
public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId,
Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler, Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler, List<ChatMessage> temporaryMemory, TokenUsage sum,
List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors,
boolean mustSwitchToWorkerThread, boolean switchToWorkerForEmission, Context executionContext,
ExecutorService executor) {
this.context = context;
this.memoryId = memoryId;
this.tokenHandler = tokenHandler;
this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.toolExecuteHandler = toolExecuteHandler;
this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.errorHandler = errorHandler;
this.temporaryMemory = temporaryMemory;
Expand All @@ -115,11 +125,11 @@ public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object
}

@Override
public void onNext(String token) {
public void onPartialResponse(String partialResponse) {
execute(new Runnable() {
@Override
public void run() {
tokenHandler.accept(token);
partialResponseHandler.accept(partialResponse);
}
});

Expand Down Expand Up @@ -156,8 +166,8 @@ public Object call() throws Exception {
}

@Override
public void onComplete(Response<AiMessage> response) {
AiMessage aiMessage = response.content();
public void onCompleteResponse(ChatResponse completeResponse) {
AiMessage aiMessage = completeResponse.aiMessage();

if (aiMessage.hasToolExecutionRequests()) {
// Tools execution may block the caller thread. When the caller thread is the event loop thread, and
Expand All @@ -182,40 +192,61 @@ public void run() {
QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage);
}

context.streamingChatModel.generate(
QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId),
ChatRequest chatRequest = ChatRequest.builder()
.messages(messagesToSend(memoryId))
.toolSpecifications(toolSpecifications)
.build();
QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
partialResponseHandler,
toolExecuteHandler,
completeResponseHandler,
completionHandler,
errorHandler,
temporaryMemory,
TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
toolSpecifications,
new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
tokenHandler,
toolExecuteHandler,
completionHandler,
errorHandler,
temporaryMemory,
TokenUsage.sum(tokenUsage, response.tokenUsage()),
toolSpecifications,
toolExecutors,
mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor));
toolExecutors,
mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor);
context.streamingChatModel.chat(chatRequest, handler);
}
});
} else {
if (completionHandler != null) {
if (completeResponseHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
ChatResponse finalChatResponse = ChatResponse.builder()
.aiMessage(aiMessage)
.metadata(ChatResponseMetadata.builder()
.id(completeResponse.metadata().id())
.modelName(completeResponse.metadata().modelName())
.tokenUsage(TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()))
.finishReason(completeResponse.metadata().finishReason())
.build())
.build();
addToMemory(aiMessage);
completionHandler.accept(Response.from(
aiMessage,
TokenUsage.sum(tokenUsage, response.tokenUsage()),
response.finishReason()));
completeResponseHandler.accept(finalChatResponse);
} finally {
shutdown(); // Terminal event, we can shutdown the executor
}
}
};
execute(runnable);
} else if (completionHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
Response<AiMessage> finalResponse = Response.from(aiMessage,
TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
completeResponse.metadata().finishReason());
addToMemory(aiMessage);
completionHandler.accept(finalResponse);
}
};
execute(runnable);
}
}
}
Expand Down
Loading

0 comments on commit 6274d69

Please sign in to comment.