Skip to content

Commit

Permalink
Update to LangChain4j 1.0.0-alpha1
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartisk committed Dec 23, 2024
1 parent 0392413 commit 4d8de9b
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
Expand Down Expand Up @@ -61,8 +62,8 @@ void testBlockingToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.hello("abc", "hi - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -92,7 +93,7 @@ void testBlockingToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand All @@ -113,7 +114,7 @@ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, In
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -122,8 +123,8 @@ void testNonBlockingToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -153,7 +154,7 @@ void testNonBlockingToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> result.get() != null);
assertThat(result.get()).contains("Tools", "supported");
assertThat(result.get()).contains("tools", "supported");
}

@Test
Expand Down Expand Up @@ -182,7 +183,7 @@ void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() {
});

Awaitility.await().until(() -> result.get() != null);
assertThat(result.get()).contains("Tools", "supported");
assertThat(result.get()).contains("tools", "supported");
}

@Test
Expand All @@ -204,7 +205,7 @@ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException,
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -213,8 +214,8 @@ void testUniToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloUni("abc", "hiUni - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -244,7 +245,7 @@ void testUniToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand All @@ -267,7 +268,7 @@ void testUniToolInvocationFromVirtualThread() throws ExecutionException, Interru
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -277,8 +278,8 @@ void testToolInvocationOnVirtualThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand All @@ -299,7 +300,7 @@ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionExcept
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand Down Expand Up @@ -328,7 +329,7 @@ void testToolInvocationOnVirtualThreadFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
Expand All @@ -31,16 +35,17 @@
* The main difference with the upstream implementation is the thread switch when receiving the `completion` event
* when there is tool execution requests.
*/
public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler<AiMessage> {
public class QuarkusAiServiceStreamingResponseHandler implements StreamingChatResponseHandler {

private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class);

private final AiServiceContext context;
private final Object memoryId;

private final Consumer<String> tokenHandler;
private final Consumer<String> partialResponseHandler;
private final Consumer<Response<AiMessage>> completionHandler;
private final Consumer<ToolExecution> toolExecuteHandler;
private final Consumer<ChatResponse> completeResponseHandler;
private final Consumer<Throwable> errorHandler;

private final List<ChatMessage> temporaryMemory;
Expand All @@ -55,8 +60,9 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon

QuarkusAiServiceStreamingResponseHandler(AiServiceContext context,
Object memoryId,
Consumer<String> tokenHandler,
Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler,
Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler,
List<ChatMessage> temporaryMemory,
Expand All @@ -69,7 +75,8 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");

this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.toolExecuteHandler = toolExecuteHandler;
this.errorHandler = errorHandler;
Expand All @@ -92,16 +99,19 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
}
}

public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, Consumer<String> tokenHandler,
Consumer<ToolExecution> toolExecuteHandler, Consumer<Response<AiMessage>> completionHandler,
public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId,
Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler, Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler, List<ChatMessage> temporaryMemory, TokenUsage sum,
List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors,
boolean mustSwitchToWorkerThread, boolean switchToWorkerForEmission, Context executionContext,
ExecutorService executor) {
this.context = context;
this.memoryId = memoryId;
this.tokenHandler = tokenHandler;
this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.toolExecuteHandler = toolExecuteHandler;
this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.errorHandler = errorHandler;
this.temporaryMemory = temporaryMemory;
Expand All @@ -115,11 +125,11 @@ public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object
}

@Override
public void onNext(String token) {
public void onPartialResponse(String partialResponse) {
execute(new Runnable() {
@Override
public void run() {
tokenHandler.accept(token);
partialResponseHandler.accept(partialResponse);
}
});

Expand Down Expand Up @@ -156,8 +166,8 @@ public Object call() throws Exception {
}

@Override
public void onComplete(Response<AiMessage> response) {
AiMessage aiMessage = response.content();
public void onCompleteResponse(ChatResponse completeResponse) {
AiMessage aiMessage = completeResponse.aiMessage();

if (aiMessage.hasToolExecutionRequests()) {
// Tools execution may block the caller thread. When the caller thread is the event loop thread, and
Expand All @@ -182,40 +192,61 @@ public void run() {
QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage);
}

context.streamingChatModel.generate(
QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId),
ChatRequest chatRequest = ChatRequest.builder()
.messages(messagesToSend(memoryId))
.toolSpecifications(toolSpecifications)
.build();
QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
partialResponseHandler,
toolExecuteHandler,
completeResponseHandler,
completionHandler,
errorHandler,
temporaryMemory,
TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
toolSpecifications,
new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
tokenHandler,
toolExecuteHandler,
completionHandler,
errorHandler,
temporaryMemory,
TokenUsage.sum(tokenUsage, response.tokenUsage()),
toolSpecifications,
toolExecutors,
mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor));
toolExecutors,
mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor);
context.streamingChatModel.chat(chatRequest, handler);
}
});
} else {
if (completionHandler != null) {
if (completeResponseHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
ChatResponse finalChatResponse = ChatResponse.builder()
.aiMessage(aiMessage)
.metadata(ChatResponseMetadata.builder()
.id(completeResponse.metadata().id())
.modelName(completeResponse.metadata().modelName())
.tokenUsage(TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()))
.finishReason(completeResponse.metadata().finishReason())
.build())
.build();
addToMemory(aiMessage);
completionHandler.accept(Response.from(
aiMessage,
TokenUsage.sum(tokenUsage, response.tokenUsage()),
response.finishReason()));
completeResponseHandler.accept(finalChatResponse);
} finally {
shutdown(); // Terminal event, we can shutdown the executor
}
}
};
execute(runnable);
} else if (completionHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
Response<AiMessage> finalResponse = Response.from(aiMessage,
TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
completeResponse.metadata().finishReason());
addToMemory(aiMessage);
completionHandler.accept(finalResponse);
}
};
execute(runnable);
}
}
}
Expand Down
Loading

0 comments on commit 4d8de9b

Please sign in to comment.