From d06a779a090766f6140a8a952915a3a0620cb4f9 Mon Sep 17 00:00:00 2001 From: humberto Date: Sat, 15 Jun 2024 18:24:10 +0200 Subject: [PATCH] Prompt Optimisation for Llama3 --- .../langchain4j/ollama/tools/Tools.java | 7 +-- .../langchain4j/ollama/tools/ToolsTest.java | 4 +- ...dler.java => ExperimentalToolsChatLM.java} | 54 +++++++++---------- .../ollama/OllamaChatLanguageModel.java | 9 ++-- 4 files changed, 35 insertions(+), 39 deletions(-) rename model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/{ToolsHandler.java => ExperimentalToolsChatLM.java} (76%) diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java index d0a4fc2c3..18c1f6766 100644 --- a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/Tools.java @@ -30,7 +30,7 @@ int add(int a, int b) { @Singleton @SuppressWarnings("unused") static class ExpenseService { - @Tool("get condominium expenses for given dates.") + @Tool("Get expenses for a given condominium, from date and to date.") public String getExpenses(String condominium, String fromDate, String toDate) { String result = String.format(""" The Expenses for %s from %s to %s are: @@ -55,9 +55,4 @@ public void sendAnEmail(String content) { """); } } - - @Tool(name = "__conversational_response", value = "Respond conversationally if no other tools should be called for a given query.") - public String conversation(String response) { - return response; - } } diff --git a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/ToolsTest.java b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/ToolsTest.java index 1d15d72d6..62546f528 100644 --- a/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/ToolsTest.java +++ b/model-providers/ollama/deployment/src/test/java/io/quarkiverse/langchain4j/ollama/tools/ToolsTest.java @@ -43,7 +43,7 @@ public interface Assistant { @SystemMessage(""" You are a property manager assistant, answering to co-owners requests. Format the date as YYYY-MM-DD and the time as HH:MM - Today is {{current_time}} use this date as date time reference + Today is {{current_date_time}} use this date as date time reference The co-owners is living in the following condominium: {condominium} """) @UserMessage(""" @@ -86,7 +86,7 @@ public interface PoemService { @ActivateRequestContext void send_a_poem() { String response = poemService.writeAPoem("Condominium Rives de marne", 4); - assertThat(response).contains("he poem has been sent by email."); + assertThat(response).contains("sent by email"); } @RegisterAiService(modelName = MODEL_NAME) diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ExperimentalToolsChatLM.java similarity index 76% rename from model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java rename to model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ExperimentalToolsChatLM.java index 581976cce..ae6a293cf 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ToolsHandler.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/ExperimentalToolsChatLM.java @@ -20,26 +20,33 @@ import dev.langchain4j.model.output.TokenUsage; import io.quarkus.runtime.annotations.RegisterForReflection; -public class ToolsHandler { +public class ExperimentalToolsChatLM { - private static final Logger log = Logger.getLogger(ToolsHandler.class); + private static final Logger log = Logger.getLogger(ExperimentalToolsChatLM.class); private static final PromptTemplate DEFAULT_SYSTEM_TEMPLATE = PromptTemplate .from(""" - You are a helpful AI assistant responding to user requests. + --- Context --- + {context} + --------------- + + You are a helpful AI assistant responding to user requests taking into account the previous context. You have access to the following tools: {tools} - - Select the most appropriate tools from this list, and respond with a JSON object containing required "tools" and "response" fields: + + Create a list of most appropriate tools to call in order to answer to the user request. + If no tools are required respond with response field directly. + Respond with a JSON object containing required "tools" and required not null "response" fields: - "tools": a list of selected tools in JSON format, each with: - "name": - - "inputs": + - "inputs": - "result_id": - - "response": - + - "response": < Summary of tools used with your response using tools result_id> + Guidelines: - - Reference previous tools results using the format: $(xxx), where xxx is a result_id. + - Only reference previous tools results using the format: $(xxx), where xxx is a previous result_id. - Break down complex requests into sequential and necessary tools. + - Use previous results through result_id for inputs response, do not invent them. """); @RegisterForReflection @@ -53,26 +60,21 @@ record ToolResponse(String name, Map inputs, String result_id) { public Response chat(OllamaClient client, Builder builder, List messages, List toolSpecifications, ToolSpecification toolThatMustBeExecuted) { - // Test if it's an AI Service request with tool results + // Test if it's an AI request with tools execution response. boolean hasResultMessages = messages.stream().anyMatch(m -> m.role() == Role.TOOL_EXECUTION_RESULT); if (hasResultMessages) { String result = messages.stream().filter(term -> term.role() == Role.ASSISTANT) .map(Message::content).collect(Collectors.joining("\n")); return Response.from(AiMessage.from(result)); } - + // Creates Chat request builder.format("json"); - Message groupedSystemMessage = createSystemMessageWithTools(messages, toolSpecifications); + Message systemMessage = createSystemMessageWithTools(messages, toolSpecifications); List otherMessages = messages.stream().filter(cm -> cm.role() != Role.SYSTEM).toList(); - Message initialUserMessage = Message.builder() - .role(Role.USER) - .content("--- " + otherMessages.get(0).content() + " ---").build(); - - List messagesWithTools = new ArrayList<>(messages.size() + 1); - messagesWithTools.add(groupedSystemMessage); - messagesWithTools.addAll(otherMessages.subList(1, otherMessages.size())); - messagesWithTools.add(initialUserMessage); + List messagesWithTools = new ArrayList<>(otherMessages.size() + 1); + messagesWithTools.add(systemMessage); + messagesWithTools.addAll(otherMessages); builder.messages(messagesWithTools); @@ -82,18 +84,16 @@ public Response chat(OllamaClient client, Builder builder, List messages, List toolSpecifications) { - Prompt prompt = DEFAULT_SYSTEM_TEMPLATE.apply( - Map.of("tools", Json.toJson(toolSpecifications))); - String initialSystemMessages = messages.stream().filter(sm -> sm.role() == Role.SYSTEM) .map(Message::content) .collect(Collectors.joining("\n")); - + Prompt prompt = DEFAULT_SYSTEM_TEMPLATE.apply( + Map.of("tools", Json.toJson(toolSpecifications), + "context", initialSystemMessages)); return Message.builder() .role(Role.SYSTEM) - .content(prompt.text() + "\n" + initialSystemMessages) + .content(prompt.text()) .build(); - } private AiMessage handleResponse(ChatResponse response, List toolSpecifications) { @@ -123,7 +123,7 @@ private AiMessage handleResponse(ChatResponse response, List } } - if (toolResponses.response != null) { + if (toolResponses.response != null && !toolResponses.response().isEmpty()) { return new AiMessage(toolResponses.response, toolExecutionRequests); } return AiMessage.from(toolExecutionRequests); diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java index 67d4d630c..77d186588 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/OllamaChatLanguageModel.java @@ -20,14 +20,14 @@ public class OllamaChatLanguageModel implements ChatLanguageModel { private final String model; private final String format; private final Options options; - private final ToolsHandler toolsHandler; + private final ExperimentalToolsChatLM experimentalToolsChatLM; private OllamaChatLanguageModel(Builder builder) { client = new OllamaClient(builder.baseUrl, builder.timeout, builder.logRequests, builder.logResponses); model = builder.model; format = builder.format; options = builder.options; - toolsHandler = builder.experimentalTool ? new ToolsHandler() : null; + experimentalToolsChatLM = builder.experimentalTool ? new ExperimentalToolsChatLM() : null; } public static Builder builder() { @@ -61,9 +61,10 @@ private Response generate(List messages, .options(options) .format(format) .stream(false); - boolean isToolNeeded = toolsHandler != null && toolSpecifications != null && !toolSpecifications.isEmpty(); + boolean isToolNeeded = experimentalToolsChatLM != null && toolSpecifications != null && !toolSpecifications.isEmpty(); if (isToolNeeded) { - return toolsHandler.chat(client, requestBuilder, ollamaMessages, toolSpecifications, toolThatMustBeExecuted); + return experimentalToolsChatLM.chat(client, requestBuilder, ollamaMessages, toolSpecifications, + toolThatMustBeExecuted); } else { ChatResponse response = client.chat(requestBuilder.build()); AiMessage aiMessage = AiMessage.from(response.message().content());