From b048c6f1714b71c33084960024c1664d9fa0a32c Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Tue, 12 Nov 2024 11:14:46 +0100 Subject: [PATCH 1/2] Polish watsonx code --- .../watsonx/bean/TextChatMessage.java | 13 +++-- .../watsonx/client/WatsonxRestApi.java | 24 ++++++++- .../filter/BearerTokenHeaderFactory.java | 2 +- .../watsonx/runtime/WatsonxRecorder.java | 51 ++++++++----------- 4 files changed, 50 insertions(+), 40 deletions(-) diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java index 091d86f36..487cdc35a 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.watsonx.bean; import static io.quarkiverse.langchain4j.watsonx.WatsonxUtils.base64Image; +import static java.util.Objects.isNull; import java.util.ArrayList; import java.util.List; @@ -131,8 +132,9 @@ public static TextChatMessageUser of(UserMessage userMessage) { } case IMAGE -> { var imageContent = ImageContent.class.cast(content); - var base64 = "data:image/%s;base64,%s".formatted( - imageContent.image().mimeType(), + var mimeType = imageContent.image().mimeType(); + var base64 = "data:%s;base64,%s".formatted( + isNull(mimeType) ? 
"image" : mimeType, base64Image(imageContent.image())); values.add(Map.of( "type", "image_url", @@ -225,11 +227,8 @@ public record TextChatParameterFunction(String name, String description, Map embeddingModel(LangChain4jWatsonxConfig runtimeC throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); } - String iamUrl = watsonConfig.iam().baseUrl().toExternalForm(); - WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, - createTokenGenerator(watsonConfig.iam(), - firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()))); + String apiKey = firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); URL url; try { @@ -170,7 +167,7 @@ public Supplier embeddingModel(LangChain4jWatsonxConfig runtimeC EmbeddingModelConfig embeddingModelConfig = watsonConfig.embeddingModel(); var builder = WatsonxEmbeddingModel.builder() - .tokenGenerator(tokenGenerator) + .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), watsonConfig.logRequests())) @@ -209,10 +206,7 @@ public Supplier scoringModel(LangChain4jWatsonxConfig runtimeConfi throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS)); } - String iamUrl = watsonConfig.iam().baseUrl().toExternalForm(); - WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, - createTokenGenerator(watsonConfig.iam(), - firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()))); + String apiKey = firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); URL url; try { @@ -223,7 +217,7 @@ public Supplier scoringModel(LangChain4jWatsonxConfig runtimeConfi ScoringModelConfig rerankModelConfig = watsonConfig.scoringModel(); var builder = WatsonxScoringModel.builder() - 
.tokenGenerator(tokenGenerator) + .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, rerankModelConfig.logRequests(), watsonConfig.logRequests())) @@ -242,17 +236,6 @@ public WatsonxScoringModel get() { }; } - private Function createTokenGenerator(IAMConfig iamConfig, String apiKey) { - return new Function() { - - @Override - public WatsonxTokenGenerator apply(String iamUrl) { - return new WatsonxTokenGenerator(iamConfig.baseUrl(), iamConfig.timeout().orElse(Duration.ofSeconds(10)), - iamConfig.grantType(), apiKey); - } - }; - } - private WatsonxChatModel.Builder chatBuilder(LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName); @@ -263,10 +246,7 @@ private WatsonxChatModel.Builder chatBuilder(LangChain4jWatsonxConfig runtimeCon } ChatModelConfig chatModelConfig = watsonConfig.chatModel(); - String iamUrl = watsonConfig.iam().baseUrl().toExternalForm(); - WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, - createTokenGenerator(watsonConfig.iam(), - firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()))); + String apiKey = firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); URL url; try { @@ -276,7 +256,7 @@ private WatsonxChatModel.Builder chatBuilder(LangChain4jWatsonxConfig runtimeCon } return WatsonxChatModel.builder() - .tokenGenerator(tokenGenerator) + .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), watsonConfig.logRequests())) @@ -306,10 +286,7 @@ private WatsonxGenerationModel.Builder generationBuilder(LangChain4jWatsonxConfi } GenerationModelConfig 
generationModelConfig = watsonConfig.generationModel(); - String iamUrl = watsonConfig.iam().baseUrl().toExternalForm(); - WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl, - createTokenGenerator(watsonConfig.iam(), - firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()))); + String apiKey = firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey()); URL url; try { @@ -323,7 +300,7 @@ private WatsonxGenerationModel.Builder generationBuilder(LangChain4jWatsonxConfi String promptJoiner = generationModelConfig.promptJoiner(); return WatsonxGenerationModel.builder() - .tokenGenerator(tokenGenerator) + .tokenGenerator(createTokenGenerator(watsonConfig.iam(), apiKey)) .url(url) .timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10))) .logRequests(firstOrDefault(false, generationModelConfig.logRequests(), watsonConfig.logRequests())) @@ -348,6 +325,18 @@ private WatsonxGenerationModel.Builder generationBuilder(LangChain4jWatsonxConfi .promptJoiner(promptJoiner); } + private WatsonxTokenGenerator createTokenGenerator(IAMConfig iamConfig, String apiKey) { + return tokenGeneratorCache.computeIfAbsent(apiKey, + new Function() { + @Override + public WatsonxTokenGenerator apply(String iamUrl) { + return new WatsonxTokenGenerator(iamConfig.baseUrl(), + iamConfig.timeout().orElse(Duration.ofSeconds(10)), + iamConfig.grantType(), apiKey); + } + }); + } + private LangChain4jWatsonxConfig.WatsonConfig correspondingWatsonRuntimeConfig(LangChain4jWatsonxConfig runtimeConfig, String configName) { LangChain4jWatsonxConfig.WatsonConfig watsonConfig; From cb6a4da0e64da85b816b4fc3cd110ca0595f6a15 Mon Sep 17 00:00:00 2001 From: Andrea Di Maio Date: Sat, 30 Nov 2024 18:21:37 +0100 Subject: [PATCH 2/2] Enable tools support in streaming responses for Ollama --- ...maStreamingChatLanguageModelSmokeTest.java | 281 ++++++++++++++++++ .../ollama/OllamaChatLanguageModel.java | 27 +- 
.../OllamaStreamingChatLanguageModel.java | 79 ++++- .../langchain4j/ollama/ToolCall.java | 17 ++ 4 files changed, 367 insertions(+), 37 deletions(-) create mode 100644 model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStreamingChatLanguageModelSmokeTest.java diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStreamingChatLanguageModelSmokeTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStreamingChatLanguageModelSmokeTest.java new file mode 100644 index 000000000..1405ba009 --- /dev/null +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/OllamaStreamingChatLanguageModelSmokeTest.java @@ -0,0 +1,281 @@ +package io.quarkiverse.langchain4j.ollama.deployment; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.stubbing.Scenario; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import 
io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.mutiny.Multi; + +public class OllamaStreamingChatLanguageModelSmokeTest extends WiremockAware { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(Calculator.class)) + .overrideConfigKey("quarkus.langchain4j.ollama.base-url", WiremockAware.wiremockUrlForConfig()) + .overrideConfigKey("quarkus.langchain4j.devservices.enabled", "false"); + + @Singleton + @RegisterAiService(tools = Calculator.class) + interface AIServiceWithTool { + Multi streaming(@MemoryId String memoryId, @dev.langchain4j.service.UserMessage String text); + } + + @Singleton + @RegisterAiService + interface AIServiceWithoutTool { + Multi streaming(@dev.langchain4j.service.UserMessage String text); + } + + @Singleton + static class Calculator { + @Tool("Execute the sum of two numbers") + public int sum(int firstNumber, int secondNumber) { + return firstNumber + secondNumber; + } + } + + @Inject + AIServiceWithTool aiServiceWithTool; + + @Inject + AIServiceWithoutTool aiServiceWithoutTool; + + @Inject + ChatMemoryStore memory; + + @Test + void test_1() { + wiremock().register( + post(urlEqualTo("/api/chat")) + .withRequestBody(equalToJson(""" + { + "model" : "llama3.2", + "messages" : [ { + "role" : "user", + "content" : "Hello" + }], + "options" : { + "temperature" : 0.8, + "top_k" : 40, + "top_p" : 0.9 + }, + "stream" : true + } + """)) + .willReturn(aResponse() + .withHeader("Content-Type", "application/x-ndjson") + .withBody( + """ + {"model":"llama3.2","created_at":"2024-11-30T09:03:42.312611426Z","message":{"role":"assistant","content":"Hello"},"done":false} + {"model":"llama3.2","created_at":"2024-11-30T09:03:42.514215351Z","message":{"role":"assistant","content":"!"},"done":false} + 
{"model":"llama3.2","created_at":"2024-11-30T09:03:44.109059873Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":4821417857,"load_duration":2508844071,"prompt_eval_count":11,"prompt_eval_duration":514000000,"eval_count":10,"eval_duration":1797000000}"""))); + + var result = aiServiceWithoutTool.streaming("Hello").collect().asList().await().indefinitely(); + assertEquals(List.of("Hello", "!"), result); + } + + @Test + void test_2() { + wiremock().register( + post(urlEqualTo("/api/chat")) + .withRequestBody(equalToJson(""" + { + "model" : "llama3.2", + "messages" : [ { + "role" : "user", + "content" : "Hello" + }], + "tools" : [ { + "type" : "function", + "function" : { + "name" : "sum", + "description" : "Execute the sum of two numbers", + "parameters" : { + "type" : "object", + "properties" : { + "firstNumber" : { + "type" : "integer" + }, + "secondNumber" : { + "type" : "integer" + } + }, + "required" : [ "firstNumber", "secondNumber" ] + } + } + } ], + "options" : { + "temperature" : 0.8, + "top_k" : 40, + "top_p" : 0.9 + }, + "stream" : true + } + """)) + .willReturn(aResponse() + .withHeader("Content-Type", "application/x-ndjson") + .withBody( + """ + {"model":"llama3.2","created_at":"2024-11-30T09:03:42.312611426Z","message":{"role":"assistant","content":"Hello"},"done":false} + {"model":"llama3.2","created_at":"2024-11-30T09:03:42.514215351Z","message":{"role":"assistant","content":"!"},"done":false} + {"model":"llama3.2","created_at":"2024-11-30T09:03:44.109059873Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":4821417857,"load_duration":2508844071,"prompt_eval_count":11,"prompt_eval_duration":514000000,"eval_count":10,"eval_duration":1797000000}"""))); + + var result = aiServiceWithTool.streaming("1", "Hello").collect().asList().await().indefinitely(); + assertEquals(List.of("Hello", "!"), result); + } + + @Test + void test_3() { + wiremock() + .register( + 
post(urlEqualTo("/api/chat")) + .inScenario("") + .whenScenarioStateIs(Scenario.STARTED) + .willSetStateTo("TOOL_CALL") + .withRequestBody(equalToJson(""" + { + "model" : "llama3.2", + "messages" : [ { + "role" : "user", + "content" : "1 + 1" + }], + "tools" : [ { + "type" : "function", + "function" : { + "name" : "sum", + "description" : "Execute the sum of two numbers", + "parameters" : { + "type" : "object", + "properties" : { + "firstNumber" : { + "type" : "integer" + }, + "secondNumber" : { + "type" : "integer" + } + }, + "required" : [ "firstNumber", "secondNumber" ] + } + } + } ], + "options" : { + "temperature" : 0.8, + "top_k" : 40, + "top_p" : 0.9 + }, + "stream" : true + } + """)) + + .willReturn(aResponse() + .withHeader("Content-Type", "application/x-ndjson") + .withBody( + """ + {"model":"llama3.1","created_at":"2024-11-30T16:36:02.833930413Z","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"sum","arguments":{"firstNumber":1,"secondNumber":1}}}]},"done":false} + {"model":"llama3.1","created_at":"2024-11-30T16:36:04.368016152Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":28825672145,"load_duration":29961281,"prompt_eval_count":169,"prompt_eval_duration":3906000000,"eval_count":22,"eval_duration":24887000000}"""))); + + wiremock() + .register( + post(urlEqualTo("/api/chat")) + .inScenario("") + .whenScenarioStateIs("TOOL_CALL") + .willSetStateTo("AI_RESPONSE") + .withRequestBody(equalToJson(""" + { + "model" : "llama3.2", + "messages" : [ { + "role" : "user", + "content" : "1 + 1" + }, { + "role" : "assistant", + "tool_calls" : [ { + "function" : { + "name" : "sum", + "arguments" : { + "firstNumber" : 1, + "secondNumber" : 1 + } + } + } ] + }, { + "role" : "tool", + "content" : "2" + } ], + "tools" : [ { + "type" : "function", + "function" : { + "name" : "sum", + "description" : "Execute the sum of two numbers", + "parameters" : { + "type" : "object", + "properties" : { + 
"firstNumber" : { + "type" : "integer" + }, + "secondNumber" : { + "type" : "integer" + } + }, + "required" : [ "firstNumber", "secondNumber" ] + } + } + } ], + "options" : { + "temperature" : 0.8, + "top_k" : 40, + "top_p" : 0.9 + }, + "stream" : true + } + """)) + .willReturn(aResponse() + .withHeader("Content-Type", "application/x-ndjson") + .withBody( + """ + {"model":"llama3.1","created_at":"2024-11-30T16:36:04.368016152Z","message":{"role":"assistant","content":"The result is 2"},"done_reason":"stop","done":true,"total_duration":28825672145,"load_duration":29961281,"prompt_eval_count":169,"prompt_eval_duration":3906000000,"eval_count":22,"eval_duration":24887000000}"""))); + + var result = aiServiceWithTool.streaming("2", "1 + 1").collect().asList().await().indefinitely(); + assertEquals(List.of("The result is 2"), result); + + var messages = memory.getMessages("2"); + assertEquals("1 + 1", ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText()); + assertEquals("The result is 2", ((dev.langchain4j.data.message.AiMessage) messages.get(3)).text()); + + if (messages.get(1) instanceof AiMessage aiMessage) { + assertTrue(aiMessage.hasToolExecutionRequests()); + assertEquals("{\"firstNumber\":1,\"secondNumber\":1}", aiMessage.toolExecutionRequests().get(0).arguments()); + } else { + fail("The second message is not of type AiMessage"); + } + + if (messages.get(2) instanceof ToolExecutionResultMessage toolResultMessage) { + assertEquals(2, Integer.parseInt(toolResultMessage.text())); + } else { + fail("The third message is not of type ToolExecutionResultMessage"); + } + } +} diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java index e8ac307ea..471c951a3 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java 
+++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java @@ -6,7 +6,6 @@ import static io.quarkiverse.langchain4j.ollama.MessageMapper.toTools; import java.time.Duration; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -14,8 +13,6 @@ import org.jboss.logging.Logger; -import com.fasterxml.jackson.core.JsonProcessingException; - import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; @@ -29,7 +26,6 @@ import dev.langchain4j.model.chat.listener.ChatModelResponseContext; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; -import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; public class OllamaChatLanguageModel implements ChatLanguageModel { @@ -137,25 +133,10 @@ private static Response toResponse(ChatResponse response) { AiMessage.from(response.message().content()), new TokenUsage(response.promptEvalCount(), response.evalCount())); } else { - try { - List toolExecutionRequests = new ArrayList<>(toolCalls.size()); - for (ToolCall toolCall : toolCalls) { - ToolCall.FunctionCall functionCall = toolCall.function(); - - // TODO: we need to update LangChain4j to make ToolExecutionRequest use a map instead of a String - String argumentsStr = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER - .writeValueAsString(functionCall.arguments()); - toolExecutionRequests.add(ToolExecutionRequest.builder() - .name(functionCall.name()) - .arguments(argumentsStr) - .build()); - } - - result = Response.from(aiMessage(toolExecutionRequests), - new TokenUsage(response.promptEvalCount(), response.evalCount())); - } catch (JsonProcessingException e) { - throw new RuntimeException("Unable to parse tool call response", e); - } + List toolExecutionRequests = toolCalls.stream().map(ToolCall::toToolExecutionRequest) + .toList(); + result 
= Response.from(aiMessage(toolExecutionRequests), + new TokenUsage(response.promptEvalCount(), response.evalCount())); } return result; } diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java index e10d273fe..eca9e2635 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaStreamingChatLanguageModel.java @@ -2,23 +2,31 @@ import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static io.quarkiverse.langchain4j.ollama.MessageMapper.toOllamaMessages; +import static io.quarkiverse.langchain4j.ollama.MessageMapper.toTools; import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +import java.util.stream.Collectors; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.StreamingResponseHandler; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; import io.smallrye.mutiny.Context; /** * Use to have streaming feature on models used trough Ollama. 
*/ public class OllamaStreamingChatLanguageModel implements StreamingChatLanguageModel { + private static final String TOOLS_CONTEXT = "TOOLS"; + private static final String TOKEN_USAGE_CONTEXT = "TOKEN_USAGE"; + private static final String RESPONSE_CONTEXT = "RESPONSE"; private final OllamaClient client; private final String model; private final String format; @@ -37,18 +45,23 @@ public static OllamaStreamingChatLanguageModel.Builder builder() { } @Override - public void generate(List messages, StreamingResponseHandler handler) { + public void generate(List messages, List toolSpecifications, + StreamingResponseHandler handler) { ensureNotEmpty(messages, "messages"); + var tools = (toolSpecifications != null && toolSpecifications.size() > 0) ? toTools(toolSpecifications) : null; ChatRequest request = ChatRequest.builder() .model(model) .messages(toOllamaMessages(messages)) .options(options) .format(format) + .tools(tools) .stream(true) .build(); - Context context = Context.of("response", new ArrayList()); + Context context = Context.empty(); + context.put(RESPONSE_CONTEXT, new ArrayList()); + context.put(TOOLS_CONTEXT, new ArrayList()); client.streamingChat(request) .subscribe() @@ -58,13 +71,31 @@ public void generate(List messages, StreamingResponseHandler) context.get("response")).add(response); - handler.onNext(response.message().content()); + + if (response.message().toolCalls() != null) { + List toolContext = context.get(TOOLS_CONTEXT); + List toolCalls = response.message().toolCalls(); + toolCalls.stream() + .map(ToolCall::toToolExecutionRequest) + .forEach(toolContext::add); + } + + if (!response.message().content().isEmpty()) { + ((List) context.get(RESPONSE_CONTEXT)).add(response); + handler.onNext(response.message().content()); + } + + if (response.done()) { + TokenUsage tokenUsage = new TokenUsage( + response.promptEvalCount(), + response.evalCount(), + response.evalCount() + response.promptEvalCount()); + context.put(TOKEN_USAGE_CONTEXT, tokenUsage); +
} + } catch (Exception e) { handler.onError(e); } @@ -78,19 +109,39 @@ public void accept(Throwable error) { }, new Runnable() { @Override - @SuppressWarnings("unchecked") public void run() { - var list = ((List) context.get("response")); - StringBuilder builder = new StringBuilder(); - for (ChatResponse response : list) { - builder.append(response.message().content()); + + TokenUsage tokenUsage = context.get(TOKEN_USAGE_CONTEXT); + List chatResponses = context.get(RESPONSE_CONTEXT); + List toolExecutionRequests = context.get(TOOLS_CONTEXT); + + if (toolExecutionRequests.size() > 0) { + handler.onComplete(Response.from(AiMessage.from(toolExecutionRequests), tokenUsage)); + return; } - AiMessage message = new AiMessage(builder.toString()); - handler.onComplete(Response.from(message)); + + String response = chatResponses.stream() + .map(ChatResponse::message) + .map(Message::content) + .collect(Collectors.joining()); + + AiMessage message = new AiMessage(response); + handler.onComplete(Response.from(message, tokenUsage)); } }); } + @Override + public void generate(List messages, ToolSpecification toolSpecification, + StreamingResponseHandler handler) { + generate(messages, List.of(toolSpecification), handler); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + generate(messages, List.of(), handler); + } + /** * Builder for Ollama configuration. 
*/ diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolCall.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolCall.java index 0ef640a1b..080b6dcc1 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolCall.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolCall.java @@ -2,12 +2,29 @@ import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; + public record ToolCall(FunctionCall function) { public static ToolCall fromFunctionCall(String name, Map arguments) { return new ToolCall(new FunctionCall(name, arguments)); } + public ToolExecutionRequest toToolExecutionRequest() { + try { + return ToolExecutionRequest.builder() + .name(function.name) + .arguments(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER + .writeValueAsString(function.arguments())) + .build(); + } catch (JsonProcessingException e) { + throw new RuntimeException("Unable to parse tool call response", e); + } + } + public record FunctionCall(String name, Map arguments) { }