Skip to content

Commit

Permalink
Update to LangChain4j 1.0.0-alpha1
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartisk committed Dec 23, 2024
1 parent 588a42c commit 50eab2f
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
Expand Down Expand Up @@ -61,8 +62,8 @@ void testBlockingToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.hello("abc", "hi - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -92,7 +93,7 @@ void testBlockingToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand All @@ -113,7 +114,7 @@ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, In
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -122,8 +123,8 @@ void testNonBlockingToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -153,7 +154,7 @@ void testNonBlockingToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> result.get() != null);
assertThat(result.get()).contains("Tools", "supported");
assertThat(result.get()).contains("tools", "supported");
}

@Test
Expand Down Expand Up @@ -182,7 +183,7 @@ void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() {
});

Awaitility.await().until(() -> result.get() != null);
assertThat(result.get()).contains("Tools", "supported");
assertThat(result.get()).contains("tools", "supported");
}

@Test
Expand All @@ -204,7 +205,7 @@ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException,
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -213,8 +214,8 @@ void testUniToolInvocationFromWorkerThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloUni("abc", "hiUni - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand Down Expand Up @@ -244,7 +245,7 @@ void testUniToolInvocationFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand All @@ -267,7 +268,7 @@ void testUniToolInvocationFromVirtualThread() throws ExecutionException, Interru
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand All @@ -277,8 +278,8 @@ void testToolInvocationOnVirtualThread() {
String uuid = UUID.randomUUID().toString();
assertThatThrownBy(() -> aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid)
.collect().asList().map(l -> String.join(" ", l)).await().indefinitely())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Tools", "supported");
.isInstanceOf(UnsupportedFeatureException.class)
.hasMessageContaining("tools", "supported");
}

@Test
Expand All @@ -299,7 +300,7 @@ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionExcept
}
}).get();

assertThat(r).contains("Tools", "supported");
assertThat(r).contains("tools", "supported");
}

@Test
Expand Down Expand Up @@ -328,7 +329,7 @@ void testToolInvocationOnVirtualThreadFromEventLoop() {
});

Awaitility.await().until(() -> failure.get() != null || result.get() != null);
assertThat(failure.get()).hasMessageContaining("Tools", "supported");
assertThat(failure.get()).hasMessageContaining("tools", "supported");
assertThat(result.get()).isNull();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
Expand All @@ -31,16 +35,17 @@
* The main difference with the upstream implementation is the thread switch when receiving the `completion` event
* when there is tool execution requests.
*/
public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler<AiMessage> {
public class QuarkusAiServiceStreamingResponseHandler implements StreamingChatResponseHandler {

private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class);

private final AiServiceContext context;
private final Object memoryId;

private final Consumer<String> tokenHandler;
private final Consumer<String> partialResponseHandler;
private final Consumer<Response<AiMessage>> completionHandler;
private final Consumer<ToolExecution> toolExecuteHandler;
private final Consumer<ChatResponse> completeResponseHandler;
private final Consumer<Throwable> errorHandler;

private final List<ChatMessage> temporaryMemory;
Expand All @@ -55,8 +60,9 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon

QuarkusAiServiceStreamingResponseHandler(AiServiceContext context,
Object memoryId,
Consumer<String> tokenHandler,
Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler,
Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler,
List<ChatMessage> temporaryMemory,
Expand All @@ -69,7 +75,8 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");

this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.toolExecuteHandler = toolExecuteHandler;
this.errorHandler = errorHandler;
Expand All @@ -92,16 +99,19 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
}
}

public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, Consumer<String> tokenHandler,
Consumer<ToolExecution> toolExecuteHandler, Consumer<Response<AiMessage>> completionHandler,
public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId,
Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler, Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler, List<ChatMessage> temporaryMemory, TokenUsage sum,
List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors,
boolean mustSwitchToWorkerThread, boolean switchToWorkerForEmission, Context executionContext,
ExecutorService executor) {
this.context = context;
this.memoryId = memoryId;
this.tokenHandler = tokenHandler;
this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.toolExecuteHandler = toolExecuteHandler;
this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.errorHandler = errorHandler;
this.temporaryMemory = temporaryMemory;
Expand All @@ -115,11 +125,11 @@ public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object
}

@Override
public void onNext(String token) {
public void onPartialResponse(String partialResponse) {
execute(new Runnable() {
@Override
public void run() {
tokenHandler.accept(token);
partialResponseHandler.accept(partialResponse);
}
});

Expand Down Expand Up @@ -156,8 +166,8 @@ public Object call() throws Exception {
}

@Override
public void onComplete(Response<AiMessage> response) {
AiMessage aiMessage = response.content();
public void onCompleteResponse(ChatResponse completeResponse) {
AiMessage aiMessage = completeResponse.aiMessage();

if (aiMessage.hasToolExecutionRequests()) {
// Tools execution may block the caller thread. When the caller thread is the event loop thread, and
Expand All @@ -182,40 +192,61 @@ public void run() {
QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage);
}

context.streamingChatModel.generate(
QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId),
ChatRequest chatRequest = ChatRequest.builder()
.messages(messagesToSend(memoryId))
.toolSpecifications(toolSpecifications)
.build();
QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
partialResponseHandler,
toolExecuteHandler,
completeResponseHandler,
completionHandler,
errorHandler,
temporaryMemory,
TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
toolSpecifications,
new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
tokenHandler,
toolExecuteHandler,
completionHandler,
errorHandler,
temporaryMemory,
TokenUsage.sum(tokenUsage, response.tokenUsage()),
toolSpecifications,
toolExecutors,
mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor));
toolExecutors,
mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor);
context.streamingChatModel.chat(chatRequest, handler);
}
});
} else {
if (completionHandler != null) {
if (completeResponseHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
ChatResponse finalChatResponse = ChatResponse.builder()
.aiMessage(aiMessage)
.metadata(ChatResponseMetadata.builder()
.id(completeResponse.metadata().id())
.modelName(completeResponse.metadata().modelName())
.tokenUsage(TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()))
.finishReason(completeResponse.metadata().finishReason())
.build())
.build();
addToMemory(aiMessage);
completionHandler.accept(Response.from(
aiMessage,
TokenUsage.sum(tokenUsage, response.tokenUsage()),
response.finishReason()));
completeResponseHandler.accept(finalChatResponse);
} finally {
shutdown(); // Terminal event, we can shutdown the executor
}
}
};
execute(runnable);
} else if (completionHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
Response<AiMessage> finalResponse = Response.from(aiMessage,
TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
completeResponse.metadata().finishReason());
addToMemory(aiMessage);
completionHandler.accept(finalResponse);
}
};
execute(runnable);
}
}
}
Expand Down
Loading

0 comments on commit 50eab2f

Please sign in to comment.