diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/VariableHandlerTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/VariableHandlerTest.java new file mode 100644 index 000000000..8f34784b9 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/VariableHandlerTest.java @@ -0,0 +1,67 @@ +package io.quarkiverse.langchain4j.test; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.output.TokenUsage; +import io.quarkiverse.langchain4j.data.AiStatsMessage; +import io.quarkiverse.langchain4j.runtime.aiservice.VariableHandler; +import io.quarkus.test.QuarkusUnitTest; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static org.assertj.core.api.Assertions.assertThat; + +public class VariableHandlerTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)); + + @Test + void test_substitution_on_arguments() { + VariableHandler variableHandler = new VariableHandler(); + + String arguments = "{\"arg0\": 2.0, \"arg1\": $(id1)}"; + + variableHandler.addVariable("id1", "3.1"); + + ToolExecutionRequest request = ToolExecutionRequest.builder() + .arguments(arguments) + .build(); + + ToolExecutionRequest modifiedRequest = variableHandler.substituteVariables(request); + + assertThat(modifiedRequest.arguments()).isEqualTo("{\"arg0\": 2.0, \"arg1\": 3.1}"); + } + + @Test + void test_substitution_on_ai_stats_message() { + VariableHandler variableHandler = new VariableHandler(); + + String text = "The expected result is $(result1)."; + + variableHandler.addVariable("result1", "3.1"); + + AiMessage aiMessage = AiMessage.aiMessage(text); + AiStatsMessage aiStatsMessage = AiStatsMessage.from(aiMessage, new TokenUsage()); + AiMessage modifiedMessage = variableHandler.substituteVariables(aiStatsMessage); + + assertThat(modifiedMessage.text()).isEqualTo("The expected result is 3.1."); + } + + @Test + void test_substitution_on_ai_message() { + VariableHandler variableHandler = new VariableHandler(); + + String text = "The expected result is $(result1)."; + + variableHandler.addVariable("result1", "3.1"); + + AiMessage aiMessage = AiMessage.aiMessage(text); + AiMessage modifiedMessage = variableHandler.substituteVariables(aiMessage); + + assertThat(modifiedMessage.text()).isEqualTo(text); + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/data/AiStatsMessage.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/data/AiStatsMessage.java index 316c50050..33072df54 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/data/AiStatsMessage.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/data/AiStatsMessage.java @@ -11,7 +11,7 @@ * This class is the equivalent of Langchain4j AiMessage. * It contains the token usage from the response that produce this AiMessage. * And add the possibility to update the text in case of text containing Tools Result variables. - * Needed for @{@link io.quarkiverse.langchain4j.runtime.aiservice.ToolsResultMemory} + * Needed for @{@link io.quarkiverse.langchain4j.runtime.aiservice.VariableHandler} * Example of usage in ExperimentalParallelToolsDelegate in Ollama model provider */ public class AiStatsMessage extends AiMessage { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 689bca887..e058382ee 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -222,9 +222,8 @@ public void accept(Response message) { break; } - ChatMemory chatMemory = context.chatMemory(memoryId); List tmpToolExecutionResultMessages = new ArrayList<>(); - ToolsResultMemory toolsResultMemory = new ToolsResultMemory(); + VariableHandler variableHandler = new VariableHandler(); for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { log.debugv("Attempting to execute tool {0}", toolExecutionRequest); @@ -232,10 +231,10 @@ public void accept(Response message) { if (toolExecutor == null) { throw runtime("Tool executor %s not found", toolExecutionRequest.name()); } - toolExecutionRequest = toolsResultMemory.substituteArguments(toolExecutionRequest); + toolExecutionRequest = variableHandler.substituteVariables(toolExecutionRequest); String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId); log.debugv("Saving result {0} into key {1}", toolExecutionResult, toolExecutionRequest.id()); - toolsResultMemory.addVariable(toolExecutionRequest.id(), toolExecutionResult); + variableHandler.addVariable(toolExecutionRequest.id(), toolExecutionResult); log.debugv("Result of {0} is {1}", toolExecutionRequest, toolExecutionResult); ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from( toolExecutionRequest, @@ -247,7 +246,7 @@ public void accept(Response message) { } // In case of tool Execution request we need to update the AiMessage with tools results // before adding it into chatMemory - aiMessage = toolsResultMemory.substituteAiMessage(aiMessage); + aiMessage = variableHandler.substituteVariables(aiMessage); if (context.hasChatMemory()) { chatMemory.add(aiMessage); tmpToolExecutionResultMessages.forEach(chatMemory::add); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ToolsResultMemory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/VariableHandler.java similarity index 93% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ToolsResultMemory.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/VariableHandler.java index 4de8315f6..c29c6a6ba 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ToolsResultMemory.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/VariableHandler.java @@ -18,7 +18,7 @@ * See usage in @{@link AiServiceMethodImplementationSupport} */ @Experimental -public class ToolsResultMemory { +public class VariableHandler { static Pattern VARIABLE_PATTERN = Pattern.compile("\\$\\((.*?)\\)"); @@ -28,7 +28,7 @@ public void addVariable(String var, String value) { variables.put(var, value); } - public AiMessage substituteAiMessage(AiMessage message) { + public AiMessage substituteVariables(AiMessage message) { if (message.text() == null) { return message; } @@ -39,7 +39,7 @@ public AiMessage substituteAiMessage(AiMessage message) { return message; } - public ToolExecutionRequest substituteArguments(ToolExecutionRequest toolExecutionRequest) { + public ToolExecutionRequest substituteVariables(ToolExecutionRequest toolExecutionRequest) { return ToolExecutionRequest.builder() .id(toolExecutionRequest.id()) .name(toolExecutionRequest.name()) diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/tool/ExperimentalParallelToolsDelegate.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/tool/ExperimentalParallelToolsDelegate.java index 0bc32ccb8..ed361dd6f 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/tool/ExperimentalParallelToolsDelegate.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/tool/ExperimentalParallelToolsDelegate.java @@ -143,7 +143,7 @@ public Response generate(List messages, for (ToolResponse toolResponse : toolResponses.actions) { if (!availableTools.contains(toolResponse.name)) { throw new RuntimeException(String.format( - "Ollama server wants to call a name '%s' that is not part of the available tools %s", + "Ollama server wants to call '%s' tool that is not part of the available tools %s", toolResponse.name, availableTools)); } else { getToolSpecification(toolResponse, toolSpecifications)