From f7d76a0a70189327e26cf81080bba3a053d0d70b Mon Sep 17 00:00:00 2001 From: Jan Martiska Date: Mon, 23 Dec 2024 11:09:49 +0100 Subject: [PATCH] Update to LangChain4j 1.0.0-alpha1 --- ...tionModelWithStreamingUnsupportedTest.java | 35 +++---- ...rkusAiServiceStreamingResponseHandler.java | 91 +++++++++++++------ .../QuarkusAiServiceTokenStream.java | 65 ++++++++----- .../chroma/ChromaEmbeddingStore.java | 5 + .../infinispan/InfinispanEmbeddingStore.java | 5 + .../pinecone/PineconeEmbeddingStore.java | 5 + .../redis/RedisEmbeddingStore.java | 5 + .../deployment/OllamaJsonOutputTest.java | 5 +- .../deployment/OllamaTextOutputTest.java | 3 +- .../openai/AzureOpenAiStreamingChatModel.java | 14 ++- .../langchain4j/watsonx/WatsonxChatModel.java | 13 +++ .../watsonx/WatsonxGenerationModel.java | 13 +++ pom.xml | 4 +- 13 files changed, 183 insertions(+), 80 deletions(-) diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingUnsupportedTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingUnsupportedTest.java index 2ae8b683f..b877a7ffa 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingUnsupportedTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingUnsupportedTest.java @@ -24,6 +24,7 @@ import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.exception.UnsupportedFeatureException; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -61,8 +62,8 @@ void testBlockingToolInvocationFromWorkerThread() { String uuid = UUID.randomUUID().toString(); assertThatThrownBy(() -> aiService.hello("abc", "hi - " + uuid) .collect().asList().map(l -> String.join(" ", l)).await().indefinitely()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Tools", "supported"); + .isInstanceOf(UnsupportedFeatureException.class) + .hasMessageContaining("tools", "supported"); } @Test @@ -92,7 +93,7 @@ void testBlockingToolInvocationFromEventLoop() { }); Awaitility.await().until(() -> failure.get() != null || result.get() != null); - assertThat(failure.get()).hasMessageContaining("Tools", "supported"); + assertThat(failure.get()).hasMessageContaining("tools", "supported"); assertThat(result.get()).isNull(); } @@ -113,7 +114,7 @@ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, In } }).get(); - assertThat(r).contains("Tools", "supported"); + assertThat(r).contains("tools", "supported"); } @Test @@ -122,8 +123,8 @@ void testNonBlockingToolInvocationFromWorkerThread() { String uuid = UUID.randomUUID().toString(); assertThatThrownBy(() -> aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid) .collect().asList().map(l -> String.join(" ", l)).await().indefinitely()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Tools", "supported"); + .isInstanceOf(UnsupportedFeatureException.class) + .hasMessageContaining("tools", "supported"); } @Test @@ -153,7 +154,7 @@ void testNonBlockingToolInvocationFromEventLoop() { }); Awaitility.await().until(() -> result.get() != null); - assertThat(result.get()).contains("Tools", "supported"); + assertThat(result.get()).contains("tools", "supported"); } @Test @@ -182,7 +183,7 @@ void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() { }); Awaitility.await().until(() -> result.get() != null); - assertThat(result.get()).contains("Tools", "supported"); + assertThat(result.get()).contains("tools", "supported"); } @Test @@ -204,7 +205,7 @@ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, } }).get(); - assertThat(r).contains("Tools", "supported"); + assertThat(r).contains("tools", "supported"); } @Test @@ -213,8 +214,8 @@ void testUniToolInvocationFromWorkerThread() { String uuid = UUID.randomUUID().toString(); assertThatThrownBy(() -> aiService.helloUni("abc", "hiUni - " + uuid) .collect().asList().map(l -> String.join(" ", l)).await().indefinitely()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Tools", "supported"); + .isInstanceOf(UnsupportedFeatureException.class) + .hasMessageContaining("tools", "supported"); } @Test @@ -244,7 +245,7 @@ void testUniToolInvocationFromEventLoop() { }); Awaitility.await().until(() -> failure.get() != null || result.get() != null); - assertThat(failure.get()).hasMessageContaining("Tools", "supported"); + assertThat(failure.get()).hasMessageContaining("tools", "supported"); assertThat(result.get()).isNull(); } @@ -267,7 +268,7 @@ void testUniToolInvocationFromVirtualThread() throws ExecutionException, Interru } }).get(); - assertThat(r).contains("Tools", "supported"); + assertThat(r).contains("tools", "supported"); } @Test @@ -277,8 +278,8 @@ void testToolInvocationOnVirtualThread() { String uuid = UUID.randomUUID().toString(); assertThatThrownBy(() -> aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid) .collect().asList().map(l -> String.join(" ", l)).await().indefinitely()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Tools", "supported"); + .isInstanceOf(UnsupportedFeatureException.class) + .hasMessageContaining("tools", "supported"); } @Test @@ -299,7 +300,7 @@ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionExcept } }).get(); - assertThat(r).contains("Tools", "supported"); + assertThat(r).contains("tools", "supported"); } @Test @@ -328,7 +329,7 @@ void testToolInvocationOnVirtualThreadFromEventLoop() { }); Awaitility.await().until(() -> failure.get() != null || result.get() != null); - assertThat(failure.get()).hasMessageContaining("Tools", "supported"); + assertThat(failure.get()).hasMessageContaining("tools", "supported"); assertThat(result.get()).isNull(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java index 60d9aa18e..f8d984c95 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java @@ -19,6 +19,10 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; +import dev.langchain4j.model.chat.response.ChatResponseMetadata; +import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.service.AiServiceContext; @@ -31,16 +35,17 @@ * The main difference with the upstream implementation is the thread switch when receiving the `completion` event * when there is tool execution requests. */ -public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler { +public class QuarkusAiServiceStreamingResponseHandler implements StreamingChatResponseHandler { private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class); private final AiServiceContext context; private final Object memoryId; - private final Consumer tokenHandler; + private final Consumer partialResponseHandler; private final Consumer> completionHandler; private final Consumer toolExecuteHandler; + private final Consumer completeResponseHandler; private final Consumer errorHandler; private final List temporaryMemory; @@ -55,8 +60,9 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, - Consumer tokenHandler, + Consumer partialResponseHandler, Consumer toolExecuteHandler, + Consumer completeResponseHandler, Consumer> completionHandler, Consumer errorHandler, List temporaryMemory, @@ -69,7 +75,8 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon this.context = ensureNotNull(context, "context"); this.memoryId = ensureNotNull(memoryId, "memoryId"); - this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler"); + this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler"); + this.completeResponseHandler = completeResponseHandler; this.completionHandler = completionHandler; this.toolExecuteHandler = toolExecuteHandler; this.errorHandler = errorHandler; @@ -92,16 +99,19 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon } } - public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, Consumer tokenHandler, - Consumer toolExecuteHandler, Consumer> completionHandler, + public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, + Consumer partialResponseHandler, + Consumer toolExecuteHandler, Consumer completeResponseHandler, + Consumer> completionHandler, Consumer errorHandler, List temporaryMemory, TokenUsage sum, List toolSpecifications, Map toolExecutors, boolean mustSwitchToWorkerThread, boolean switchToWorkerForEmission, Context executionContext, ExecutorService executor) { this.context = context; this.memoryId = memoryId; - this.tokenHandler = tokenHandler; + this.partialResponseHandler = ensureNotNull(partialResponseHandler, "partialResponseHandler"); this.toolExecuteHandler = toolExecuteHandler; + this.completeResponseHandler = completeResponseHandler; this.completionHandler = completionHandler; this.errorHandler = errorHandler; this.temporaryMemory = temporaryMemory; @@ -115,11 +125,11 @@ public QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object } @Override - public void onNext(String token) { + public void onPartialResponse(String partialResponse) { execute(new Runnable() { @Override public void run() { - tokenHandler.accept(token); + partialResponseHandler.accept(partialResponse); } }); @@ -156,8 +166,8 @@ public Object call() throws Exception { } @Override - public void onComplete(Response response) { - AiMessage aiMessage = response.content(); + public void onCompleteResponse(ChatResponse completeResponse) { + AiMessage aiMessage = completeResponse.aiMessage(); if (aiMessage.hasToolExecutionRequests()) { // Tools execution may block the caller thread. When the caller thread is the event loop thread, and @@ -182,40 +192,61 @@ public void run() { QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage); } - context.streamingChatModel.generate( - QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId), + ChatRequest chatRequest = ChatRequest.builder() + .messages(messagesToSend(memoryId)) + .toolSpecifications(toolSpecifications) + .build(); + QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler( + context, + memoryId, + partialResponseHandler, + toolExecuteHandler, + completeResponseHandler, + completionHandler, + errorHandler, + temporaryMemory, + TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()), toolSpecifications, - new QuarkusAiServiceStreamingResponseHandler( - context, - memoryId, - tokenHandler, - toolExecuteHandler, - completionHandler, - errorHandler, - temporaryMemory, - TokenUsage.sum(tokenUsage, response.tokenUsage()), - toolSpecifications, - toolExecutors, - mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor)); + toolExecutors, + mustSwitchToWorkerThread, switchToWorkerForEmission, executionContext, executor); + context.streamingChatModel.chat(chatRequest, handler); } }); } else { - if (completionHandler != null) { + if (completeResponseHandler != null) { Runnable runnable = new Runnable() { @Override public void run() { try { + ChatResponse finalChatResponse = ChatResponse.builder() + .aiMessage(aiMessage) + .metadata(ChatResponseMetadata.builder() + .id(completeResponse.metadata().id()) + .modelName(completeResponse.metadata().modelName()) + .tokenUsage(TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage())) + .finishReason(completeResponse.metadata().finishReason()) + .build()) + .build(); addToMemory(aiMessage); - completionHandler.accept(Response.from( - aiMessage, - TokenUsage.sum(tokenUsage, response.tokenUsage()), - response.finishReason())); + completeResponseHandler.accept(finalChatResponse); } finally { shutdown(); // Terminal event, we can shutdown the executor } } }; execute(runnable); + } else if (completionHandler != null) { + Runnable runnable = new Runnable() { + @Override + public void run() { + Response finalResponse = Response.from(aiMessage, + TokenUsage.sum(tokenUsage, completeResponse.metadata().tokenUsage()), + completeResponse.metadata().finishReason()); + addToMemory(aiMessage); + completionHandler.accept(finalResponse); + } + }; + execute(runnable); } } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java index 57a9b8276..0c9e77c22 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java @@ -1,7 +1,6 @@ package io.quarkiverse.langchain4j.runtime.aiservice; import static dev.langchain4j.internal.Utils.copyIfNotNull; -import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import static java.util.Collections.emptyList; @@ -15,6 +14,8 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.exception.IllegalConfigurationException; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.rag.content.Content; @@ -42,13 +43,16 @@ public class QuarkusAiServiceTokenStream implements TokenStream { private final boolean switchToWorkerThreadForToolExecution; private final boolean switchToWorkerForEmission; - private Consumer tokenHandler; + private Consumer partialResponseHandler; private Consumer> contentsHandler; private Consumer errorHandler; private Consumer> completionHandler; private Consumer toolExecuteHandler; + private Consumer completeResponseHandler; + private int onPartialResponseInvoked; private int onNextInvoked; + private int onCompleteResponseInvoked; private int onCompleteInvoked; private int onRetrievedInvoked; private int onErrorInvoked; @@ -74,9 +78,16 @@ public QuarkusAiServiceTokenStream(List messages, this.switchToWorkerForEmission = switchToWorkerForEmission; } + @Override + public TokenStream onPartialResponse(Consumer partialResponseHandler) { + this.partialResponseHandler = partialResponseHandler; + this.onPartialResponseInvoked++; + return this; + } + @Override public TokenStream onNext(Consumer tokenHandler) { - this.tokenHandler = tokenHandler; + this.partialResponseHandler = tokenHandler; this.onNextInvoked++; return this; } @@ -95,6 +106,13 @@ public TokenStream onToolExecuted(Consumer toolExecuteHandler) { return this; } + @Override + public TokenStream onCompleteResponse(Consumer completionHandler) { + this.completeResponseHandler = completionHandler; + this.onCompleteResponseInvoked++; + return this; + } + @Override public TokenStream onComplete(Consumer> completionHandler) { this.completionHandler = completionHandler; @@ -119,11 +137,17 @@ public TokenStream ignoreErrors() { @Override public void start() { validateConfiguration(); + ChatRequest chatRequest = new ChatRequest.Builder() + .messages(messages) + .toolSpecifications(toolSpecifications) + .build(); + QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler( context, memoryId, - tokenHandler, + partialResponseHandler, toolExecuteHandler, + completeResponseHandler, completionHandler, errorHandler, initTemporaryMemory(context, messages), @@ -138,39 +162,38 @@ public void start() { contentsHandler.accept(retrievedContents); } - if (isNullOrEmpty(toolSpecifications)) { - context.streamingChatModel.generate(messages, handler); - } else { - try { - // Some model do not support function calling with tool specifications - context.streamingChatModel.generate(messages, toolSpecifications, handler); - } catch (Exception e) { - if (errorHandler != null) { - errorHandler.accept(e); - } + try { + // Some model do not support function calling with tool specifications + context.streamingChatModel.chat(chatRequest, handler); + } catch (Exception e) { + if (errorHandler != null) { + errorHandler.accept(e); } } } private void validateConfiguration() { - if (onNextInvoked != 1) { - throw new IllegalConfigurationException("onNext must be invoked exactly 1 time"); + if (onPartialResponseInvoked + onNextInvoked != 1) { + throw new IllegalConfigurationException("One of [onPartialResponse, onNext] " + + "must be invoked on TokenStream exactly 1 time"); } - if (onCompleteInvoked > 1) { - throw new IllegalConfigurationException("onComplete must be invoked at most 1 time"); + if (onCompleteResponseInvoked + onCompleteInvoked > 1) { + throw new IllegalConfigurationException("One of [onCompleteResponse, onComplete] " + + "can be invoked on TokenStream at most 1 time"); } if (onRetrievedInvoked > 1) { - throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time"); + throw new IllegalConfigurationException("onRetrieved can be invoked on TokenStream at most 1 time"); } if (toolExecuteInvoked > 1) { - throw new IllegalConfigurationException("onToolExecuted must be invoked at most 1 time"); + throw new IllegalConfigurationException("onToolExecuted can be invoked on TokenStream at most 1 time"); } if (onErrorInvoked + ignoreErrorsInvoked != 1) { - throw new IllegalConfigurationException("One of onError or ignoreErrors must be invoked exactly 1 time"); + throw new IllegalConfigurationException("One of [onError, ignoreErrors] " + + "must be invoked on TokenStream exactly 1 time"); } } diff --git a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java index 57e47d1e9..5b2023c1f 100644 --- a/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java +++ b/embedding-stores/chroma/runtime/src/main/java/io/quarkiverse/langchain4j/chroma/ChromaEmbeddingStore.java @@ -179,6 +179,11 @@ public List addAll(List embeddings, List textSeg return ids; } + @Override + public void addAll(List ids, List embeddings, List embedded) { + addAllInternal(ids, embeddings, embedded); + } + private void addInternal(String id, Embedding embedding, TextSegment textSegment) { addAllInternal(singletonList(id), singletonList(embedding), textSegment == null ? null : singletonList(textSegment)); } diff --git a/embedding-stores/infinispan/runtime/src/main/java/io/quarkiverse/langchain4j/infinispan/InfinispanEmbeddingStore.java b/embedding-stores/infinispan/runtime/src/main/java/io/quarkiverse/langchain4j/infinispan/InfinispanEmbeddingStore.java index d1bd4bc45..f710997e3 100644 --- a/embedding-stores/infinispan/runtime/src/main/java/io/quarkiverse/langchain4j/infinispan/InfinispanEmbeddingStore.java +++ b/embedding-stores/infinispan/runtime/src/main/java/io/quarkiverse/langchain4j/infinispan/InfinispanEmbeddingStore.java @@ -87,6 +87,11 @@ public List addAll(List embeddings, List embedde return ids; } + @Override + public void addAll(List ids, List embeddings, List embedded) { + addAllInternal(ids, embeddings, embedded); + } + private void addInternal(String id, Embedding embedding, TextSegment embedded) { addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); } diff --git a/embedding-stores/pinecone/runtime/src/main/java/io/quarkiverse/langchain4j/pinecone/PineconeEmbeddingStore.java b/embedding-stores/pinecone/runtime/src/main/java/io/quarkiverse/langchain4j/pinecone/PineconeEmbeddingStore.java index c72e3c9ed..fb3c8402f 100644 --- a/embedding-stores/pinecone/runtime/src/main/java/io/quarkiverse/langchain4j/pinecone/PineconeEmbeddingStore.java +++ b/embedding-stores/pinecone/runtime/src/main/java/io/quarkiverse/langchain4j/pinecone/PineconeEmbeddingStore.java @@ -150,6 +150,11 @@ public List addAll(List embeddings, List embedde return ids; } + @Override + public void addAll(List ids, List embeddings, List embedded) { + addAllInternal(ids, embeddings, embedded); + } + @Override public List> findRelevant(Embedding embedding, int maxResults, double minScore) { indexExists.get(); diff --git a/embedding-stores/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java b/embedding-stores/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java index 93d447e1e..264955c64 100644 --- a/embedding-stores/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java +++ b/embedding-stores/redis/runtime/src/main/java/io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.java @@ -120,6 +120,11 @@ public List addAll(List embeddings, List embedde return ids; } + @Override + public void addAll(List ids, List embeddings, List embedded) { + addAllInternal(ids, embeddings, embedded); + } + private void addInternal(String id, Embedding embedding, TextSegment embedded) { addAllInternal(singletonList(id), singletonList(embedding), embedded == null ? null : singletonList(embedded)); } diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java index 4d4ebc6c1..1055ff3da 100644 --- a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaJsonOutputTest.java @@ -58,14 +58,13 @@ void extract() { "content": "Tell me something about Alan Wake\\nYou must answer strictly in the following JSON format: {\\n\\\"firstname\\\": (The firstname; type: string),\\n\\\"lastname\\\": (The lastname; type: string)\\n}" } ], - "stream": false, "options": { "temperature": 0.8, "top_k": 40, "top_p": 0.9 }, - "tools": [], - "format": "json" + "format": "json", + "stream": false }""")) .willReturn(aResponse() .withHeader("Content-Type", "application/json") diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java index 48f08eea0..81d00f777 100644 --- a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaTextOutputTest.java @@ -55,8 +55,7 @@ void extract() { "temperature": 0.8, "top_k": 40, "top_p": 0.9 - }, - "tools": [] + } }""")) .willReturn(aResponse() .withHeader("Content-Type", "application/json") diff --git a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java index b3c1ce7d4..d28b15025 100644 --- a/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java +++ b/model-providers/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/AzureOpenAiStreamingChatModel.java @@ -41,6 +41,7 @@ import dev.langchain4j.model.chat.listener.ChatModelRequestContext; import dev.langchain4j.model.chat.listener.ChatModelResponse; import dev.langchain4j.model.chat.listener.ChatModelResponseContext; +import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.openai.OpenAiStreamingResponseBuilder; import dev.langchain4j.model.output.Response; import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient; @@ -202,7 +203,7 @@ private void generate(List messages, } }) .onComplete(() -> { - Response response = responseBuilder.build(); + ChatResponse response = responseBuilder.build(); ChatModelResponse modelListenerResponse = createModelListenerResponse( responseId.get(), @@ -220,10 +221,13 @@ private void generate(List messages, } }); - handler.onComplete(response); + Response aiResponse = Response.from(response.aiMessage(), + response.tokenUsage(), + response.finishReason()); + handler.onComplete(aiResponse); }) .onError((error) -> { - Response response = responseBuilder.build(); + ChatResponse response = responseBuilder.build(); ChatModelResponse modelListenerPartialResponse = createModelListenerResponse( responseId.get(), @@ -282,7 +286,7 @@ private ChatModelRequest createModelListenerRequest(ChatCompletionRequest reques private ChatModelResponse createModelListenerResponse(String responseId, String responseModel, - Response response) { + ChatResponse response) { if (response == null) { return null; } @@ -292,7 +296,7 @@ private ChatModelResponse createModelListenerResponse(String responseId, .model(responseModel) .tokenUsage(response.tokenUsage()) .finishReason(response.finishReason()) - .aiMessage(response.content()) + .aiMessage(response.aiMessage()) .build(); } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java index 37a8369e9..62aee9700 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -5,6 +5,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; @@ -15,11 +16,13 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.Capability; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.chat.listener.ChatModelRequest; import dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.chat.request.ChatRequestParameters; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -278,6 +281,16 @@ public Integer call() throws Exception { }); } + @Override + public ChatRequestParameters defaultRequestParameters() { + return null; + } + + @Override + public Set supportedCapabilities() { + return Set.of(); + } + @Override public Response generate(List messages) { return generate(messages, List.of()); diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java index dd8d6e073..39a07df6c 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxGenerationModel.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; @@ -16,11 +17,13 @@ import dev.langchain4j.data.message.SystemMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.Capability; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TokenCountEstimator; import dev.langchain4j.model.chat.listener.ChatModelRequest; import dev.langchain4j.model.chat.listener.ChatModelResponse; +import dev.langchain4j.model.chat.request.ChatRequestParameters; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -70,6 +73,16 @@ public WatsonxGenerationModel(Builder builder) { .build(); } + @Override + public ChatRequestParameters defaultRequestParameters() { + return null; + } + + @Override + public Set supportedCapabilities() { + return Set.of(); + } + @Override public Response generate(List messages) { diff --git a/pom.xml b/pom.xml index 453ddecad..528be5888 100644 --- a/pom.xml +++ b/pom.xml @@ -33,8 +33,8 @@ 3.15.2 - 0.37.0-SNAPSHOT - 0.36.2 + 1.0.0-alpha1 + 1.0.0-alpha1 1.0.1 2.0.4 3.26.3