Update to LangChain4j 1.0.0-alpha1
jmartisk committed Dec 23, 2024
1 parent 64f1f22 commit 8ddda9b
Showing 6 changed files with 142 additions and 58 deletions.
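The heart of the change is LangChain4j's reworked streaming API: `StreamingResponseHandler<AiMessage>` with `onNext`/`onComplete` gives way to `StreamingChatResponseHandler` with `onPartialResponse`/`onCompleteResponse`, and `generate(messages, toolSpecifications, handler)` becomes `chat(ChatRequest, handler)`. Below is a minimal sketch of the new call shape as it appears in this diff — the model wiring and handler bodies are illustrative, not code from this commit:

```java
import java.util.List;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;

class StreamingMigrationSketch {

    static void stream(StreamingChatLanguageModel model) {
        // Old style (removed by this commit):
        //   model.generate(messages, toolSpecifications,
        //           new StreamingResponseHandler<AiMessage>() { ... });
        List<ChatMessage> messages = List.of(UserMessage.from("Hello"));
        ChatRequest request = ChatRequest.builder()
                .messages(messages)
                .build();
        model.chat(request, new StreamingChatResponseHandler() {
            @Override
            public void onPartialResponse(String partialResponse) {
                System.out.print(partialResponse); // one streamed chunk; was onNext(String token)
            }

            @Override
            public void onCompleteResponse(ChatResponse completeResponse) {
                // Aggregated message plus metadata; was onComplete(Response<AiMessage>)
                System.out.println("\n" + completeResponse.aiMessage().text());
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
```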
@@ -19,6 +19,10 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
+ import dev.langchain4j.model.chat.request.ChatRequest;
+ import dev.langchain4j.model.chat.response.ChatResponse;
+ import dev.langchain4j.model.chat.response.ChatResponseMetadata;
+ import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
@@ -31,16 +35,17 @@
* The main difference with the upstream implementation is the thread switch when receiving the `completion` event
* when there are tool execution requests.
*/
- public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler<AiMessage> {
+ public class QuarkusAiServiceStreamingResponseHandler implements StreamingChatResponseHandler {

private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class);

private final AiServiceContext context;
private final Object memoryId;

- private final Consumer<String> tokenHandler;
+ private final Consumer<String> partialResponseHandler;
private final Consumer<Response<AiMessage>> completionHandler;
private final Consumer<ToolExecution> toolExecuteHandler;
+ private final Consumer<ChatResponse> completeResponseHandler;
private final Consumer<Throwable> errorHandler;

private final List<ChatMessage> temporaryMemory;
@@ -55,8 +60,9 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon

QuarkusAiServiceStreamingResponseHandler(AiServiceContext context,
Object memoryId,
- Consumer<String> tokenHandler,
+ Consumer<String> partialResponseHandler,
Consumer<ToolExecution> toolExecuteHandler,
+ Consumer<ChatResponse> completeResponseHandler,
Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler,
List<ChatMessage> temporaryMemory,
@@ -69,7 +75,8 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");

- this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
+ this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
+ this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.toolExecuteHandler = toolExecuteHandler;
this.errorHandler = errorHandler;
@@ -92,16 +99,19 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
}
}

- public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, Consumer<String> tokenHandler,
- Consumer<ToolExecution> toolExecuteHandler, Consumer<Response<AiMessage>> completionHandler,
+ public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId,
+ Consumer<String> partialResponseHandler,
+ Consumer<ToolExecution> toolExecuteHandler, Consumer<ChatResponse> completeResponseHandler,
+ Consumer<Response<AiMessage>> completionHandler,
Consumer<Throwable> errorHandler, List<ChatMessage> temporaryMemory, TokenUsage sum,
List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors,
boolean mustSwitchToWorkerThread, boolean switchToWorkerForEmission, Context executionContext,
ExecutorService executor) {
this.context = context;
this.memoryId = memoryId;
- this.tokenHandler = tokenHandler;
+ this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler");
this.toolExecuteHandler = toolExecuteHandler;
+ this.completeResponseHandler = completeResponseHandler;
this.completionHandler = completionHandler;
this.errorHandler = errorHandler;
this.temporaryMemory = temporaryMemory;
@@ -115,11 +125,11 @@ public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object
}

@Override
- public void onNext(String token) {
+ public void onPartialResponse(String partialResponse) {
execute(new Runnable() {
@Override
public void run() {
- tokenHandler.accept(token);
+ partialResponseHandler.accept(partialResponse);
}
});

@@ -156,8 +166,8 @@ public Object call() throws Exception {
}

@Override
- public void onComplete(Response<AiMessage> response) {
- AiMessage aiMessage = response.content();
+ public void onCompleteResponse(ChatResponse completeResponse) {
+ AiMessage aiMessage = completeResponse.aiMessage();

if (aiMessage.hasToolExecutionRequests()) {
// Tools execution may block the caller thread. When the caller thread is the event loop thread, and
@@ -182,40 +192,61 @@ public void run() {
QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage);
}

- context.streamingChatModel.generate(
- QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId),
+ ChatRequest chatRequest = ChatRequest.builder()
+ .messages(messagesToSend(memoryId))
+ .toolSpecifications(toolSpecifications)
+ .build();
+ QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
+ context,
+ memoryId,
+ partialResponseHandler,
+ toolExecuteHandler,
+ completeResponseHandler,
+ completionHandler,
+ errorHandler,
+ temporaryMemory,
+ TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
toolSpecifications,
- new QuarkusAiServiceStreamingResponseHandler(
- context,
- memoryId,
- tokenHandler,
- toolExecuteHandler,
- completionHandler,
- errorHandler,
- temporaryMemory,
- TokenUsage.sum(tokenUsage, response.tokenUsage()),
- toolSpecifications,
- toolExecutors,
- mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor));
+ toolExecutors,
+ mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor);
+ context.streamingChatModel.chat(chatRequest, handler);
}
});
} else {
- if (completionHandler != null) {
+ if (completeResponseHandler != null) {
Runnable runnable = new Runnable() {
@Override
public void run() {
try {
+ ChatResponse finalChatResponse = ChatResponse.builder()
+ .aiMessage(aiMessage)
+ .metadata(ChatResponseMetadata.builder()
+ .id(completeResponse.metadata().id())
+ .modelName(completeResponse.metadata().modelName())
+ .tokenUsage(TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()))
+ .finishReason(completeResponse.metadata().finishReason())
+ .build())
+ .build();
addToMemory(aiMessage);
- completionHandler.accept(Response.from(
- aiMessage,
- TokenUsage.sum(tokenUsage, response.tokenUsage()),
- response.finishReason()));
+ completeResponseHandler.accept(finalChatResponse);
} finally {
shutdown(); // Terminal event, we can shutdown the executor
}
}
};
execute(runnable);
+ } else if (completionHandler != null) {
+ Runnable runnable = new Runnable() {
+ @Override
+ public void run() {
+ Response<AiMessage> finalResponse = Response.from(aiMessage,
+ TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()),
+ completeResponse.metadata().finishReason());
+ addToMemory(aiMessage);
+ completionHandler.accept(finalResponse);
+ }
+ };
+ execute(runnable);
}
}
}
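Note how `onCompleteResponse` above keeps both completion paths alive during the migration: the new `ChatResponse`-based `completeResponseHandler` takes precedence, and the legacy `Response<AiMessage>`-based `completionHandler` remains as a fallback, each receiving token usage summed across tool-execution rounds. A hypothetical helper (not in this commit) showing the bridge the fallback branch performs:

```java
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.Response;

final class LegacyResponseBridge {

    // Mirrors the else-if branch of onCompleteResponse(): the new ChatResponse
    // carries everything the legacy Response<AiMessage> did.
    static Response<AiMessage> toLegacy(ChatResponse response) {
        return Response.from(
                response.aiMessage(), // was response.content()
                response.metadata().tokenUsage(),
                response.metadata().finishReason());
    }
}
```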
@@ -1,7 +1,6 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import static dev.langchain4j.internal.Utils.copyIfNotNull;
- import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Collections.emptyList;
@@ -15,6 +14,8 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.exception.IllegalConfigurationException;
+ import dev.langchain4j.model.chat.request.ChatRequest;
+ import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.content.Content;
@@ -42,13 +43,16 @@ public class QuarkusAiServiceTokenStream implements TokenStream {
private final boolean switchToWorkerThreadForToolExecution;
private final boolean switchToWorkerForEmission;

- private Consumer<String> tokenHandler;
+ private Consumer<String> partialResponseHandler;
private Consumer<List<Content>> contentsHandler;
private Consumer<Throwable> errorHandler;
private Consumer<Response<AiMessage>> completionHandler;
private Consumer<ToolExecution> toolExecuteHandler;
+ private Consumer<ChatResponse> completeResponseHandler;

+ private int onPartialResponseInvoked;
private int onNextInvoked;
+ private int onCompleteResponseInvoked;
private int onCompleteInvoked;
private int onRetrievedInvoked;
private int onErrorInvoked;
@@ -74,9 +78,16 @@ public QuarkusAiServiceTokenStream(List<ChatMessage> messages,
this.switchToWorkerForEmission = switchToWorkerForEmission;
}

+ @Override
+ public TokenStream onPartialResponse(Consumer<String> partialResponseHandler) {
+ this.partialResponseHandler = partialResponseHandler;
+ this.onPartialResponseInvoked++;
+ return this;
+ }
+
@Override
public TokenStream onNext(Consumer<String> tokenHandler) {
- this.tokenHandler = tokenHandler;
+ this.partialResponseHandler = tokenHandler;
this.onNextInvoked++;
return this;
}
@@ -95,6 +106,13 @@ public TokenStream onToolExecuted(Consumer<ToolExecution> toolExecuteHandler) {
return this;
}

+ @Override
+ public TokenStream onCompleteResponse(Consumer<ChatResponse> completionHandler) {
+ this.completeResponseHandler = completionHandler;
+ this.onCompleteResponseInvoked++;
+ return this;
+ }
+
@Override
public TokenStream onComplete(Consumer<Response<AiMessage>> completionHandler) {
this.completionHandler = completionHandler;
@@ -119,11 +137,17 @@ public TokenStream ignoreErrors() {
@Override
public void start() {
validateConfiguration();
+ ChatRequest chatRequest = new ChatRequest.Builder()
+ .messages(messages)
+ .toolSpecifications(toolSpecifications)
+ .build();
+
QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
context,
memoryId,
- tokenHandler,
+ partialResponseHandler,
toolExecuteHandler,
+ completeResponseHandler,
completionHandler,
errorHandler,
initTemporaryMemory(context, messages),
@@ -138,39 +162,38 @@ public void start() {
contentsHandler.accept(retrievedContents);
}

- if (isNullOrEmpty(toolSpecifications)) {
- context.streamingChatModel.generate(messages, handler);
- } else {
- try {
- // Some model do not support function calling with tool specifications
- context.streamingChatModel.generate(messages, toolSpecifications, handler);
- } catch (Exception e) {
- if (errorHandler != null) {
- errorHandler.accept(e);
- }
- }
+ try {
+ // Some model do not support function calling with tool specifications
+ context.streamingChatModel.chat(chatRequest, handler);
+ } catch (Exception e) {
+ if (errorHandler != null) {
+ errorHandler.accept(e);
+ }
}
}

private void validateConfiguration() {
- if (onNextInvoked != 1) {
- throw new IllegalConfigurationException("onNext must be invoked exactly 1 time");
+ if (onPartialResponseInvoked + onNextInvoked != 1) {
+ throw new IllegalConfigurationException("One of [onPartialResponse, onNext] " +
+ "must be invoked on TokenStream exactly 1 time");
}

- if (onCompleteInvoked > 1) {
- throw new IllegalConfigurationException("onComplete must be invoked at most 1 time");
+ if (onCompleteResponseInvoked + onCompleteInvoked > 1) {
+ throw new IllegalConfigurationException("One of [onCompleteResponse, onComplete] " +
+ "can be invoked on TokenStream at most 1 time");
}

if (onRetrievedInvoked > 1) {
throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time");
throw new IllegalConfigurationException("onRetrieved can be invoked on TokenStream at most 1 time");
}

if (toolExecuteInvoked > 1) {
throw new IllegalConfigurationException("onToolExecuted must be invoked at most 1 time");
throw new IllegalConfigurationException("onToolExecuted can be invoked on TokenStream at most 1 time");
}

if (onErrorInvoked + ignoreErrorsInvoked != 1) {
throw new IllegalConfigurationException("One of onError or ignoreErrors must be invoked exactly 1 time");
throw new IllegalConfigurationException("One of [onError, ignoreErrors] " +
"must be invoked on TokenStream exactly 1 time");
}
}

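The tightened `validateConfiguration()` above spells out the contract a caller must satisfy before `start()`: exactly one of `onPartialResponse`/`onNext`, at most one of `onCompleteResponse`/`onComplete`, and exactly one of `onError`/`ignoreErrors`. A minimal conforming pipeline, assuming a `TokenStream` obtained from an AI service method (not shown in this diff):

```java
import dev.langchain4j.service.TokenStream;

final class TokenStreamUsageSketch {

    static void consume(TokenStream stream) {
        stream.onPartialResponse(System.out::print) // exactly one of onPartialResponse/onNext
                .onCompleteResponse(response -> System.out.println(
                        "\nfinished: " + response.metadata().finishReason())) // at most one of onCompleteResponse/onComplete
                .onError(Throwable::printStackTrace) // exactly one of onError/ignoreErrors
                .start();
    }
}
```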
@@ -41,6 +41,7 @@
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
+ import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.openai.OpenAiStreamingResponseBuilder;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient;
@@ -202,7 +203,7 @@ private void generate(List<ChatMessage> messages,
}
})
.onComplete(() -> {
- Response<AiMessage> response = responseBuilder.build();
+ ChatResponse response = responseBuilder.build();

ChatModelResponse modelListenerResponse = createModelListenerResponse(
responseId.get(),
@@ -220,10 +221,13 @@
}
});

- handler.onComplete(response);
+ Response<AiMessage> aiResponse = Response.from(response.aiMessage(),
+ response.tokenUsage(),
+ response.finishReason());
+ handler.onComplete(aiResponse);
})
.onError((error) -> {
- Response<AiMessage> response = responseBuilder.build();
+ ChatResponse response = responseBuilder.build();

ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
responseId.get(),
@@ -282,7 +286,7 @@ private ChatModelRequest createModelListenerRequest(ChatCompletionRequest reques

private ChatModelResponse createModelListenerResponse(String responseId,
String responseModel,
- Response<AiMessage> response) {
+ ChatResponse response) {
if (response == null) {
return null;
}
@@ -292,7 +296,7 @@ private ChatModelResponse createModelListenerResponse(String responseId,
.model(responseModel)
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
- .aiMessage(response.content())
+ .aiMessage(response.aiMessage())
.build();
}

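A detail that keeps the unchanged `.tokenUsage(...)`/`.finishReason(...)` builder lines above compiling after the parameter type flips from `Response<AiMessage>` to `ChatResponse`: as the `onComplete` bridge earlier in this diff already relies on, `ChatResponse` exposes `tokenUsage()` and `finishReason()` shortcuts that read the same values as going through `metadata()`. A small sketch of that equivalence:

```java
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;

final class ChatResponseAccessorsSketch {

    static void show(ChatResponse response) {
        // Shortcut accessors, as used by the unchanged builder lines in this diff...
        TokenUsage usage = response.tokenUsage();
        FinishReason reason = response.finishReason();
        // ...match the explicit metadata path used elsewhere in the commit:
        assert usage.equals(response.metadata().tokenUsage());
        assert reason == response.metadata().finishReason();
    }
}
```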