From e33b92dad8e5ad0b9fa7f42dd628f3dc59fad094 Mon Sep 17 00:00:00 2001 From: humberto Date: Wed, 12 Jun 2024 21:30:23 +0200 Subject: [PATCH] New Prompt, but still some issues with send poem. Need to tune it to avoid unnecessary multiple calls to same tool --- .../langchain4j/deployment/ToolProcessor.java | 2 +- model-providers/ollama/deployment/pom.xml | 5 + .../ollama/deployment/ToolsTest.java | 118 ------------------ .../ollama/tools/Llama3ToolsTest.java | 88 +++++++++++++ .../langchain4j/ollama/tools/Tools.java | 64 ++++++++++ .../src/test/resources/application.properties | 31 +++++ ...Handler.java => AbstractToolsHandler.java} | 102 ++++++++++----- .../langchain4j/ollama/EmptyToolsHandler.java | 2 +- .../ollama/OllamaChatLanguageModel.java | 2 +- .../langchain4j/ollama/ToolsHandler.java | 2 +- .../ollama/ToolsHandlerFactory.java | 7 +- .../ollama/runtime/OllamaRecorder.java | 6 + .../runtime/config/ChatModelConfig.java | 5 + .../toolshandler/Llama3ToolsHandler.java | 30 +++++ 14 files changed, 309 insertions(+), 155 deletions(-) delete mode 100644 model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/ToolsTest.java create mode 100644 model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Llama3ToolsTest.java create mode 100644 model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java create mode 100644 model-providers/ollama/deployment/src/test/resources/application.properties rename model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/{OllamaDefaultToolsHandler.java => AbstractToolsHandler.java} (56%) create mode 100644 model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/toolshandler/Llama3ToolsHandler.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java 
index cdc88f62d..c47b046d1 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -310,7 +310,7 @@ private static String generateInvoker(MethodInfo methodInfo, ClassOutput classOu boolean toolReturnsVoid = methodInfo.returnType().kind() == Type.Kind.VOID; if (toolReturnsVoid) { - invokeMc.returnValue(invokeMc.load("Success")); + invokeMc.returnValue(invokeMc.load("Success")); // TODO: To change } else { invokeMc.returnValue(result); } diff --git a/model-providers/ollama/deployment/pom.xml b/model-providers/ollama/deployment/pom.xml index 8d49e346b..b6543e71e 100644 --- a/model-providers/ollama/deployment/pom.xml +++ b/model-providers/ollama/deployment/pom.xml @@ -36,6 +36,11 @@ quarkus-junit5-internal test + + io.quarkus + quarkus-junit5 + test + org.assertj assertj-core diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/ToolsTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/ToolsTest.java deleted file mode 100644 index 3d6847205..000000000 --- a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/deployment/ToolsTest.java +++ /dev/null @@ -1,118 +0,0 @@ -package io.quarkiverse.langchain4j.ollama.deployment; - -import static org.assertj.core.api.Assertions.assertThat; - -import jakarta.enterprise.context.control.ActivateRequestContext; -import jakarta.inject.Inject; -import jakarta.inject.Singleton; - -import org.jboss.shrinkwrap.api.ShrinkWrap; -import org.jboss.shrinkwrap.api.spec.JavaArchive; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import dev.langchain4j.agent.tool.Tool; -import dev.langchain4j.service.SystemMessage; -import dev.langchain4j.service.UserMessage; -import 
io.quarkiverse.langchain4j.RegisterAiService; -import io.quarkus.logging.Log; -import io.quarkus.test.QuarkusUnitTest; - -@Disabled("Integration tests that need an ollama server running") -public class ToolsTest { - - @RegisterExtension - static final QuarkusUnitTest unitTest = new QuarkusUnitTest() - .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) - .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.timeout", "60s") - .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-requests", "true") - .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.log-responses", "true") - .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.chat-model.temperature", "0") - .overrideRuntimeConfigKey("quarkus.langchain4j.ollama.experimental-tools", "true"); - - @Singleton - @SuppressWarnings("unused") - static class ExpenseService { - @Tool("useful for when you need to lookup condominium expenses for given dates.") - public String getExpenses(String condominium, String fromDate, String toDate) { - String result = String.format(""" - The Expenses for %s from %s to %s are: - - Expense hp12: 2800e - - Expense 2: 15000e - """, condominium, fromDate, toDate); - Log.infof(result); - return result; - } - } - - @RegisterAiService(tools = ExpenseService.class) - public interface Assistant { - @SystemMessage(""" - You are a property manager assistant, answering to co-owners requests. 
- Format the date as YYYY-MM-DD and the time as HH:MM - Today is {{current_date}} use this date as date time reference - The co-owners is leaving in the following condominium: {condominium} - """) - @UserMessage(""" - {{request}} - """) - String answer(String condominium, String request); - } - - @Inject - Assistant assistant; - - @Test - @ActivateRequestContext - void test_simple_tool() { - String response = assistant.answer("Rives de Marne", - "What are the expenses for this year ?"); - assertThat(response).contains("Expense hp12"); - } - - @Test - @ActivateRequestContext - void test_should_not_calls_tool() { - String response = assistant.answer("Rives de Marne", "What time is it ?"); - assertThat(response).doesNotContain("Expense hp12"); - } - - @Singleton - @SuppressWarnings("unused") - public static class Calculator { - @Tool("Calculates the length of a string") - String stringLengthStr(String s) { - return String.format("The length of the word %s is %d", s, s.length()); - } - - @Tool("Calculates the sum of two numbers") - String addStr(int a, int b) { - return String.format("The sum of %s and %s is %d", a, b, a + b); - } - - @Tool("Calculates the square root of a number") - String sqrtStr(int x) { - return String.format("The square root of %s is %f", x, Math.sqrt(x)); - } - } - - @RegisterAiService(tools = Calculator.class) - public interface MathAssistant { - String chat(String userMessage); - } - - @Inject - MathAssistant mathAssistant; - - @Test - @ActivateRequestContext - void test_multiple_tools() { - String msg = "What is the square root with maximal precision of the sum of the numbers of letters in the words " + - "\"hello\" and \"world\""; - String response = mathAssistant.chat(msg); - assertThat(response).contains("3.162278"); - - } - -} diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Llama3ToolsTest.java 
b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Llama3ToolsTest.java new file mode 100644 index 000000000..719cdc833 --- /dev/null +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Llama3ToolsTest.java @@ -0,0 +1,88 @@ +package io.quarkiverse.langchain4j.ollama.tools; + +import dev.langchain4j.service.SystemMessage; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.junit.QuarkusTest; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +@Disabled("Integration tests that need an ollama server running") +@DisplayName("LLM Tools test - " + Llama3ToolsTest.MODEL_NAME) +@QuarkusTest +public class Llama3ToolsTest { + + public static final String MODEL_NAME = "llama3"; + + @RegisterAiService(tools = Tools.Calculator.class, modelName = MODEL_NAME) + public interface MathAssistantLlama3 { + String chat(String userMessage); + } + + @Inject + MathAssistantLlama3 mathAssistantLlama3; + + @Test + @ActivateRequestContext + void square_of_sum_of_number_of_letters() { + String msg = "What is the square root with maximal precision of the sum " + + "of the numbers of letters in the words hello and llama"; + String response = mathAssistantLlama3.chat(msg); + assertThat(response).contains("3.1622776601683795"); + } + + @RegisterAiService(tools = Tools.ExpenseService.class, modelName = MODEL_NAME) + public interface Assistant { + @SystemMessage(""" + You are a property manager assistant, answering to co-owners requests. 
+ Format the date as YYYY-MM-DD and the time as HH:MM + Today is {{current_date}} use this date as date time reference + The co-owners are living in the following condominium: {condominium} + """) + @UserMessage(""" + {{request}} + """) + String answer(String condominium, String request); + } + @Inject + Assistant assistant; + + @Test + @ActivateRequestContext + void get_expenses() { + String response = assistant.answer("Rives de Marne", + "What are the expenses for this year ?"); + assertThat(response).contains("Expense hp12"); + } + + @Test + @ActivateRequestContext + void should_not_calls_tool() { + String response = assistant.answer("Rives de Marne", "What time is it ?"); + assertThat(response).doesNotContain("Expense hp12"); + } + + @RegisterAiService(tools = Tools.EmailService.class, modelName = MODEL_NAME) + public interface PoemService { + @SystemMessage("You are a professional poet") + @UserMessage(""" + Write a poem about {topic}. The poem should be {lines} lines long. Then send this poem by email. 
+ """) + String writeAPoem(String topic, int lines); + } + + @Inject + PoemService poemService; + + @Test + @ActivateRequestContext + void send_a_poem() { + String response = poemService.writeAPoem("Condominium Rives de marne", 4); + assertThat(response).contains("Success"); + } +} diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java new file mode 100644 index 000000000..7f1a0dd05 --- /dev/null +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java @@ -0,0 +1,64 @@ +package io.quarkiverse.langchain4j.ollama.tools; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkus.logging.Log; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Singleton; + +public class Tools { + + @Singleton + @SuppressWarnings("unused") + public static class Calculator { + @Tool("Calculates the length of a string") + int stringLength(String s) { + return s.length(); + } + + String stringLengthStr(String s) { + return String.format("The length of the word %s is %d", s, s.length()); + } + + @Tool("Calculates the sum of two numbers") + int add(int a, int b) { + return a + b; + } + + String addStr(int a, int b) { + return String.format("The sum of %s and %s is %d", a, b, a + b); + } + + @Tool("Calculates the square root of a number") + double sqrt(int x) { + return Math.sqrt(x); + } + + String sqrtStr(int x) { + return String.format("The square root of %s is %f", x, Math.sqrt(x)); + } + } + + @Singleton + @SuppressWarnings("unused") + static class ExpenseService { + @Tool("useful for when you need to lookup condominium expenses for given dates.") + public String getExpenses(String condominium, String fromDate, String toDate) { + String result = String.format(""" + The Expenses for %s from %s to %s are: + - Expense hp12: 2800e + - Expense 2: 15000e + """, 
condominium, fromDate, toDate); + Log.infof(result); + return result; + } + } + + @ApplicationScoped + static class EmailService { + @Tool("send the given content by email") + @SuppressWarnings("unused") + public void sendAnEmail(String content) { + Log.info("Tool sendAnEmail has been executed successfully!"); + } + } +} diff --git a/model-providers/ollama/deployment/src/test/resources/application.properties b/model-providers/ollama/deployment/src/test/resources/application.properties new file mode 100644 index 000000000..16f898252 --- /dev/null +++ b/model-providers/ollama/deployment/src/test/resources/application.properties @@ -0,0 +1,31 @@ +quarkus.langchain4j.ollama.log-requests = true +quarkus.langchain4j.ollama.log-responses = true +quarkus.langchain4j.ollama.chat-model.num-predict = 8192 +quarkus.langchain4j.ollama.chat-model.num-ctx = 4096 + + +# Not working llm: calebfahlgren/natural-functions , phi3, aya, mistral, gemma, +# Working llm: llama3, qwen2 +quarkus.langchain4j.ollama.llama3-2048_ctx.chat-model.model-id = llama3-2048_ctx +quarkus.langchain4j.ollama.llama3-2048_ctx.timeout = 60s +quarkus.langchain4j.ollama.llama3-2048_ctx.chat-model.temperature = 0.0 +quarkus.langchain4j.ollama.llama3-2048_ctx.chat-model.num-ctx = 4096 +quarkus.langchain4j.ollama.llama3-2048_ctx.chat-model.num-predict = 8192 +quarkus.langchain4j.ollama.llama3-2048_ctx.experimental-tools = true + +quarkus.langchain4j.ollama.llama3.chat-model.model-id = llama3-2048_ctx +quarkus.langchain4j.ollama.llama3.timeout = 60s +quarkus.langchain4j.ollama.llama3.chat-model.temperature = 0.0 +quarkus.langchain4j.ollama.llama3.chat-model.num-ctx = 4096 +quarkus.langchain4j.ollama.llama3.chat-model.num-predict = 8192 +quarkus.langchain4j.ollama.llama3.experimental-tools = true + +quarkus.langchain4j.ollama.mistral.chat-model.model-id = mistral +quarkus.langchain4j.ollama.mistral.timeout = 60s +quarkus.langchain4j.ollama.mistral.chat-model.temperature = 0.0 
+quarkus.langchain4j.ollama.mistral.experimental-tools = true + +quarkus.langchain4j.ollama.qwen2.chat-model.model-id = qwen2 +quarkus.langchain4j.ollama.qwen2.timeout = 60s +quarkus.langchain4j.ollama.qwen2.chat-model.temperature = 0.0 +quarkus.langchain4j.ollama.qwen2.experimental-tools = true \ No newline at end of file diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaDefaultToolsHandler.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/AbstractToolsHandler.java similarity index 56% rename from model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaDefaultToolsHandler.java rename to model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/AbstractToolsHandler.java index 490c52cb0..d5b3eff9f 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaDefaultToolsHandler.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/AbstractToolsHandler.java @@ -1,13 +1,5 @@ package io.quarkiverse.langchain4j.ollama; -import static io.quarkiverse.langchain4j.ollama.ChatRequest.Builder; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import dev.langchain4j.agent.tool.ToolExecutionRequest; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; @@ -16,32 +8,28 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; -public class OllamaDefaultToolsHandler implements ToolsHandler { +import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; - static final PromptTemplate DEFAULT_SYSTEM_TEMPLATE = PromptTemplate.from(""" - You have access to the following tools: +import static 
io.quarkiverse.langchain4j.ollama.ChatRequest.Builder; - {tools} +public abstract class AbstractToolsHandler implements ToolsHandler { - You must always select one of the above tools and respond with a JSON object matching the following schema, - and only this json object: - { - "tool": , - "tool_input": - } - Do not use other tools than the ones from the list above. Always provide the "tool_input" field. - If several tools are necessary, answer them sequentially. + static final Pattern PATTERN1 = Pattern.compile("Tool (.?) called with parameters"); + static final Pattern PATTERN2 = Pattern.compile("called with parameters (.?), do not call it anymore"); - When the user provides sufficient information, answer with the __conversational_response tool. - """); + abstract public PromptTemplate getDefaultSystemTemplate(); static final ToolSpecification DEFAULT_RESPONSE_TOOL = ToolSpecification.builder() .name("__conversational_response") - .description("Respond conversationally if no other tools should be called for a given query and history.") + .description("Respond conversationally if no other tools should be called for a given query and history " + + "or if the user request have been done.") .parameters(ToolParameters.builder() .type("object") .properties( - Map.of("reponse", + Map.of("response", Map.of("type", "string", "description", "Conversational response to the user."))) .required(Collections.singletonList("response")) @@ -61,7 +49,7 @@ public Builder enhanceWithTools(Builder builder, List messages, List extendedList = new ArrayList<>(toolSpecifications.size() + 1); extendedList.addAll(toolSpecifications); extendedList.add(DEFAULT_RESPONSE_TOOL); - Prompt prompt = DEFAULT_SYSTEM_TEMPLATE.apply( + Prompt prompt = getDefaultSystemTemplate().apply( Map.of("tools", Json.toJson(extendedList))); // TODO handle -> toolThatMustBeExecuted skipped for the moment @@ -75,19 +63,54 @@ public Builder enhanceWithTools(Builder builder, List messages, List otherMessages = 
messages.stream().filter(cm -> cm.role() != Role.SYSTEM).toList(); - + Message initialUserMessage = Message.builder() + .role(Role.USER) + .content("--- " + otherMessages.get(0).content() + " ---").build(); + + List lastMessages = convertAssistantMessages(otherMessages.subList(1, otherMessages.size())); +// String lastMessagesGrouped = lastMessages.stream() +// .map(Message::content) +// .collect(Collectors.joining("\n")); +// Message lastMessage = Message.builder() +// .role(Role.ASSISTANT) +// .content(lastMessagesGrouped).build(); // Add specific tools message List messagesWithTools = new ArrayList<>(messages.size() + 1); messagesWithTools.add(groupedSystemMessage); - messagesWithTools.addAll(otherMessages); + messagesWithTools.addAll(lastMessages); + messagesWithTools.add(initialUserMessage); builder.messages(messagesWithTools); return builder; } + private List convertAssistantMessages(List lastMessages) { + List messages = new ArrayList<>(); + Message assistantMsg = null; + for (Message msg: lastMessages) { + if (msg.role() == Role.ASSISTANT) { + assistantMsg = msg; + } else if (msg.role() == Role.USER) { + if( assistantMsg == null) { + messages.add(msg); + // throw new RuntimeException(" USER Message detected without corresponding ASSISTANT Message."); + } else { + messages.add(Message.builder() + .role(Role.USER) // Should be ASSISTANT but does not work, to check if not related to num_ctx + .content(String.format("%s and the result is %s", assistantMsg.content(), msg.content())).build()); + assistantMsg = null; + } + } else { + throw new RuntimeException( + String.format("Message with role %s not allowed at this stage.", msg.role())); + } + } + return messages; + } + @Override - public AiMessage getAiMessageFromResponse(ChatResponse response, List toolSpecifications) { + public AiMessage handleResponse(ChatResponse response, List toolSpecifications) { ToolResponse toolResponse; try { // Extract tools @@ -102,16 +125,34 @@ public AiMessage 
getAiMessageFromResponse(ChatResponse response, List availableTools = toolSpecifications.stream().map(ToolSpecification::name).toList(); if (!availableTools.contains(toolResponse.tool)) { - return AiMessage.from(String.format( + throw new RuntimeException(String.format( "Ollama server wants to call a tool '%s' that is not part of the available tools %s", toolResponse.tool, availableTools)); } // Extract tools request from response List toolExecutionRequests = toToolExecutionRequests(toolResponse, toolSpecifications); - return AiMessage.aiMessage(toolExecutionRequests); + // only one tool will be used. + // toolSpecifications.clear(); + // toolSpecifications.add(DEFAULT_RESPONSE_TOOL); + return new AiMessage(toolResponse.toAiMessageText(), toolExecutionRequests); } record ToolResponse(String tool, Map tool_input) { + + public String toAiMessageText() { + return String.format("Tool \"%s\" with parameters %s has been called", tool, + Json.toJson(tool_input).replace("\n", "")); + } + + public ToolResponse fromAiMessageContent(String content) { + Matcher matcher1 = PATTERN1.matcher(content); + if (matcher1.find()) + { + Map tool_input = Json.fromJson(PATTERN2.matcher(content).group(1), Map.class); + return new ToolResponse(matcher1.group(1), tool_input); + } + return null; + } } private List toToolExecutionRequests(ToolResponse toolResponse, @@ -124,6 +165,7 @@ private List toToolExecutionRequests(ToolResponse toolResp static ToolExecutionRequest toToolExecutionRequest(ToolResponse toolResponse, ToolSpecification toolSpecification) { return ToolExecutionRequest.builder() + .id(UUID.randomUUID().toString()) .name(toolSpecification.name()) .arguments(Json.toJson(toolResponse.tool_input)) .build(); diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmptyToolsHandler.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmptyToolsHandler.java index 7c8a131d6..9677fa1e0 100644 --- 
a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmptyToolsHandler.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/EmptyToolsHandler.java @@ -16,7 +16,7 @@ public ChatRequest.Builder enhanceWithTools(ChatRequest.Builder requestBuilder, } @Override - public AiMessage getAiMessageFromResponse(ChatResponse response, List toolSpecifications) { + public AiMessage handleResponse(ChatResponse response, List toolSpecifications) { return AiMessage.from(response.message().content()); } } diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java index eaf768f2f..621e44134 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java @@ -65,7 +65,7 @@ private Response generate(List messages, requestBuilder = toolsHandler.enhanceWithTools(requestBuilder, ollamaMessages, toolSpecifications, toolThatMustBeExecuted); ChatResponse response = client.chat(requestBuilder.build()); - AiMessage aiMessage = toolsHandler.getAiMessageFromResponse(response, toolSpecifications); + AiMessage aiMessage = toolsHandler.handleResponse(response, toolSpecifications); return Response.from(aiMessage, new TokenUsage(response.promptEvalCount(), response.evalCount())); } diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java index 79db30986..4e423eff7 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java +++ 
b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java @@ -25,5 +25,5 @@ ChatRequest.Builder enhanceWithTools(ChatRequest.Builder requestBuilder, List toolSpecifications); + AiMessage handleResponse(ChatResponse response, List toolSpecifications); } diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandlerFactory.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandlerFactory.java index 5f7336638..6bf53153a 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandlerFactory.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandlerFactory.java @@ -1,11 +1,12 @@ package io.quarkiverse.langchain4j.ollama; +import io.quarkiverse.langchain4j.ollama.toolshandler.Llama3ToolsHandler; + public class ToolsHandlerFactory { - private static final ToolsHandler DEFAULT = new OllamaDefaultToolsHandler(); + private static final ToolsHandler LLAMA3 = new Llama3ToolsHandler(); - @SuppressWarnings("unused") public static ToolsHandler get(String model) { - return DEFAULT; + return LLAMA3; } } diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java index 28cb681a2..53d4a3529 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/OllamaRecorder.java @@ -42,6 +42,9 @@ public Supplier chatModel(LangChain4jOllamaConfig runtimeConf if (chatModelConfig.numPredict().isPresent()) { optionsBuilder.numPredict(chatModelConfig.numPredict().getAsInt()); } + if (chatModelConfig.numCtx().isPresent()) { + 
optionsBuilder.numCtx(chatModelConfig.numCtx().getAsInt()); + } if (chatModelConfig.stop().isPresent()) { optionsBuilder.stop(chatModelConfig.stop().get()); } @@ -132,6 +135,9 @@ public Supplier streamingChatModel(LangChain4jOllama if (chatModelConfig.numPredict().isPresent()) { optionsBuilder.numPredict(chatModelConfig.numPredict().getAsInt()); } + if (chatModelConfig.numCtx().isPresent()) { + optionsBuilder.numCtx(chatModelConfig.numCtx().getAsInt()); + } if (chatModelConfig.stop().isPresent()) { optionsBuilder.stop(chatModelConfig.stop().get()); } diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java index 3817284d8..dd571ff18 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/runtime/config/ChatModelConfig.java @@ -23,6 +23,11 @@ public interface ChatModelConfig { */ OptionalInt numPredict(); + /** + * Maximum number of tokens to keep in the context + */ + OptionalInt numCtx(); + /** * Sets the stop sequences to use. 
When this pattern is encountered the LLM will stop generating text and return */ diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/toolshandler/Llama3ToolsHandler.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/toolshandler/Llama3ToolsHandler.java new file mode 100644 index 000000000..3a5da3685 --- /dev/null +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/toolshandler/Llama3ToolsHandler.java @@ -0,0 +1,30 @@ +package io.quarkiverse.langchain4j.ollama.toolshandler; + +import dev.langchain4j.model.input.PromptTemplate; +import io.quarkiverse.langchain4j.ollama.AbstractToolsHandler; + +public class Llama3ToolsHandler extends AbstractToolsHandler { + + static final PromptTemplate DEFAULT_SYSTEM_TEMPLATE = PromptTemplate.from(""" + You are a helpful AI assistant responding to user requests delimited by "---". + You have access to the following tools: + {tools} + Select the most appropriate tool for each user request and respond with a JSON object containing: + - "tool": + - "tool_input": + Follow these guidelines: + - Only use the listed tools. + - Avoid using the same tool twice. + - Use user history to avoid selecting the same tool with identical parameters more than once. + - Retrieve precise data using the tools without inventing data or parameters. + - Break down complex requests into sequential tool calls. + - Combine user history and tool descriptions to choose the best next tool. + - If a tool with the same parameters has been used, respond with "__conversational_response" using the previous result. + - When enough information is gathered, respond with "__conversational_response" using the provided data. + """); + + @Override + public PromptTemplate getDefaultSystemTemplate() { + return DEFAULT_SYSTEM_TEMPLATE; + } +}