Skip to content

Commit

Permalink
- Fix build + rename file to be more explicit + add unit test on Vari…
Browse files Browse the repository at this point in the history
…ableHandler
  • Loading branch information
humcqc committed Jul 24, 2024
1 parent 7d99874 commit 61a4343
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.quarkiverse.langchain4j.test;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.data.AiStatsMessage;
import io.quarkiverse.langchain4j.runtime.aiservice.VariableHandler;
import io.quarkus.test.QuarkusUnitTest;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import static org.assertj.core.api.Assertions.assertThat;

public class VariableHandlerTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class));

@Test
void test_substitution_on_arguments() {
VariableHandler variableHandler = new VariableHandler();

String arguments = "{\"arg0\": 2.0, \"arg1\": $(id1)}";

variableHandler.addVariable("id1", "3.1");

ToolExecutionRequest request = ToolExecutionRequest.builder()
.arguments(arguments)
.build();

ToolExecutionRequest modifiedRequest = variableHandler.substituteVariables(request);

assertThat(modifiedRequest.arguments()).isEqualTo("{\"arg0\": 2.0, \"arg1\": 3.1}");
}

@Test
void test_substitution_on_ai_stats_message() {
VariableHandler variableHandler = new VariableHandler();

String text = "The expected result is $(result1).";

variableHandler.addVariable("result1", "3.1");

AiMessage aiMessage = AiMessage.aiMessage(text);
AiStatsMessage aiStatsMessage = AiStatsMessage.from(aiMessage, new TokenUsage());
AiMessage modifiedMessage = variableHandler.substituteVariables(aiStatsMessage);

assertThat(modifiedMessage.text()).isEqualTo("The expected result is 3.1.");
}

@Test
void test_substitution_on_ai_message() {
VariableHandler variableHandler = new VariableHandler();

String text = "The expected result is $(result1).";

variableHandler.addVariable("result1", "3.1");

AiMessage aiMessage = AiMessage.aiMessage(text);
AiMessage modifiedMessage = variableHandler.substituteVariables(aiMessage);

assertThat(modifiedMessage.text()).isEqualTo(text);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* This class is the equivalent of Langchain4j AiMessage.
* It contains the token usage from the response that produce this AiMessage.
* And add the possibility to update the text in case of text containing Tools Result variables.
* Needed for @{@link io.quarkiverse.langchain4j.runtime.aiservice.ToolsResultMemory}
* Needed for @{@link io.quarkiverse.langchain4j.runtime.aiservice.VariableHandler}
* Example of usage in ExperimentalParallelToolsDelegate in Ollama model provider
*/
public class AiStatsMessage extends AiMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,20 +222,19 @@ public void accept(Response<AiMessage> message) {
break;
}

ChatMemory chatMemory = context.chatMemory(memoryId);
List<ToolExecutionResultMessage> tmpToolExecutionResultMessages = new ArrayList<>();
ToolsResultMemory toolsResultMemory = new ToolsResultMemory();
VariableHandler variableHandler = new VariableHandler();

for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
log.debugv("Attempting to execute tool {0}", toolExecutionRequest);
ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());
if (toolExecutor == null) {
throw runtime("Tool executor %s not found", toolExecutionRequest.name());
}
toolExecutionRequest = toolsResultMemory.substituteArguments(toolExecutionRequest);
toolExecutionRequest = variableHandler.substituteVariables(toolExecutionRequest);
String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
log.debugv("Saving result {0} into key {1}", toolExecutionResult, toolExecutionRequest.id());
toolsResultMemory.addVariable(toolExecutionRequest.id(), toolExecutionResult);
variableHandler.addVariable(toolExecutionRequest.id(), toolExecutionResult);
log.debugv("Result of {0} is {1}", toolExecutionRequest, toolExecutionResult);
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(
toolExecutionRequest,
Expand All @@ -247,7 +246,7 @@ public void accept(Response<AiMessage> message) {
}
// In case of tool Execution request we need to update the AiMessage with tools results
// before adding it into chatMemory
aiMessage = toolsResultMemory.substituteAiMessage(aiMessage);
aiMessage = variableHandler.substituteVariables(aiMessage);
if (context.hasChatMemory()) {
chatMemory.add(aiMessage);
tmpToolExecutionResultMessages.forEach(chatMemory::add);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* See usage in @{@link AiServiceMethodImplementationSupport}
*/
@Experimental
public class ToolsResultMemory {
public class VariableHandler {

static Pattern VARIABLE_PATTERN = Pattern.compile("\\$\\((.*?)\\)");

Expand All @@ -28,7 +28,7 @@ public void addVariable(String var, String value) {
variables.put(var, value);
}

public AiMessage substituteAiMessage(AiMessage message) {
public AiMessage substituteVariables(AiMessage message) {
if (message.text() == null) {
return message;
}
Expand All @@ -39,7 +39,7 @@ public AiMessage substituteAiMessage(AiMessage message) {
return message;
}

public ToolExecutionRequest substituteArguments(ToolExecutionRequest toolExecutionRequest) {
public ToolExecutionRequest substituteVariables(ToolExecutionRequest toolExecutionRequest) {
return ToolExecutionRequest.builder()
.id(toolExecutionRequest.id())
.name(toolExecutionRequest.name())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages,
for (ToolResponse toolResponse : toolResponses.actions) {
if (!availableTools.contains(toolResponse.name)) {
throw new RuntimeException(String.format(
"Ollama server wants to call a name '%s' that is not part of the available tools %s",
"Ollama server wants to call '%s' tool that is not part of the available tools %s",
toolResponse.name, availableTools));
} else {
getToolSpecification(toolResponse, toolSpecifications)
Expand Down

0 comments on commit 61a4343

Please sign in to comment.