diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java index 35031d1..5be92a0 100644 --- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java +++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java @@ -10,9 +10,7 @@ import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; import com.meta.cp4m.Identifier; import com.meta.cp4m.message.FBMessage; import com.meta.cp4m.message.Message; @@ -27,14 +25,11 @@ import java.time.Instant; import java.util.*; -import org.checkerframework.common.returnsreceiver.qual.This; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class HuggingFaceLlamaPrompt { - private static final Logger LOGGER = LoggerFactory.getLogger(HuggingFaceLlamaPrompt.class); - private static final ObjectMapper MAPPER = new ObjectMapper(); private final String systemMessage; private final long maxInputTokens; private final HuggingFaceTokenizer tokenizer; @@ -108,58 +103,7 @@ private int tokenCount(String message) { return encoding.getTokens().length - 1; } - // TODO: move logic into promptbuilder - private String pruneMessages(ThreadState threadState) { - - int totalTokens = 5; // Account for closing tokens at end of message - StringBuilder promptStringBuilder = new StringBuilder(); - String systemPrompt = "[INST] <>\n" + systemMessage + "\n<>\n\n"; - totalTokens += tokenCount(systemPrompt); - promptStringBuilder - .append("[INST] <>\n") - .append(systemMessage) - .append("\n<>\n\n"); - - Message.Role nextMessageSender = Message.Role.ASSISTANT; - StringBuilder contextStringBuilder = new StringBuilder(); - - List messages = threadState.messages(); - - for (int i = messages.size() - 1; i >= 0; i--) { - Message message = messages.get(i); - StringBuilder messageText = new StringBuilder(); - String text = message.message().strip(); - Message.Role user = message.role(); - boolean isUser = user == Message.Role.USER; - messageText.append(text); - if (isUser && nextMessageSender == Message.Role.ASSISTANT) { - messageText.append(" [/INST] "); - } else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) { - messageText.append(" [INST] "); - } - totalTokens += tokenCount(messageText.toString()); - if (totalTokens > maxInputTokens) { - if (contextStringBuilder.isEmpty()) { - return "I'm sorry but that request was too long for me."; - } - break; - } - contextStringBuilder.append(messageText.reverse()); - - nextMessageSender = user; - } - if (nextMessageSender == Message.Role.ASSISTANT) { - contextStringBuilder.append( - " ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after - // system prompt is not from user - } - - promptStringBuilder.append(contextStringBuilder.reverse()); - return promptStringBuilder.toString().strip(); - } - - // TODO: convert this to a class and implement the methods to replace pruneMethod - private class PromptBuilder { + private static class PromptBuilder { int totalTokens = 5; StringBuilder promptStringBuilder = new StringBuilder(); diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java deleted file mode 100644 index fe7ae10..0000000 --- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPromptBuilder.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -package com.meta.cp4m.llm; - -import ai.djl.huggingface.tokenizers.Encoding; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.meta.cp4m.message.Message; -import com.meta.cp4m.message.ThreadState; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.beans.beancontext.BeanContextChild; -import java.io.IOException; -import java.nio.file.Paths; -import java.util.*; -import java.net.URI; -import java.net.URISyntaxException; - - -public class HuggingFaceLlamaPromptBuilder { - - private static final Logger LOGGER = LoggerFactory.getLogger(HuggingFaceLlamaPromptBuilder.class); - - public String createPrompt(ThreadState threadState, HuggingFaceConfig config) { - - - URI resource = null; - try { - resource = Objects.requireNonNull(HuggingFaceLlamaPromptBuilder.class.getClassLoader().getResource("llamaTokenizer.json")).toURI(); - } catch (URISyntaxException e) { - LOGGER.error("Failed to find local llama tokenizer.json file", e); - } - - try { - assert resource != null; - HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(resource)); - return pruneMessages(threadState, config, tokenizer); - } catch (IOException e) { - LOGGER.error("Failed to initialize Llama2 tokenizer from local file", e); - } - - return "[INST] <>\n" + (config.systemMessage()) + "\n<>\n\n" + threadState.messages().get(threadState.messages().size() - 1) + " [/INST] "; - - } - - private int tokenCount(String message, HuggingFaceTokenizer tokenizer) { - Encoding encoding = tokenizer.encode(message); - return encoding.getTokens().length - 1; - } - - private String pruneMessages(ThreadState threadState, HuggingFaceConfig config, HuggingFaceTokenizer tokenizer) - throws JsonProcessingException { - - int totalTokens = 5; // Account for closing tokens at end of message - StringBuilder promptStringBuilder = new StringBuilder(); - - String systemPrompt = "[INST] <>\n" + config.systemMessage() + "\n<>\n\n"; - totalTokens += tokenCount(systemPrompt, tokenizer); - promptStringBuilder.append("[INST] <>\n").append(config.systemMessage()).append("\n<>\n\n"); - - - Message.Role nextMessageSender = Message.Role.ASSISTANT; - StringBuilder contextStringBuilder = new StringBuilder(); - - List messages = threadState.messages(); - - for (int i = messages.size() - 1; i >= 0; i--) { - Message message = messages.get(i); - StringBuilder messageText = new StringBuilder(); - String text = message.message().strip(); - Message.Role user = message.role(); - boolean isUser = user == Message.Role.USER; - messageText.append(text); - if (isUser && nextMessageSender == Message.Role.ASSISTANT) { - messageText.append(" [/INST] "); - } else if (user == Message.Role.ASSISTANT && nextMessageSender == Message.Role.USER) { - messageText.append(" [INST] "); - } - totalTokens += tokenCount(messageText.toString(), tokenizer); - if (totalTokens > config.maxInputTokens()) { - if (contextStringBuilder.isEmpty()) { - return "I'm sorry but that request was too long for me."; - } - break; - } - contextStringBuilder.append(messageText.reverse()); - - nextMessageSender = user; - } - if (nextMessageSender == Message.Role.ASSISTANT) { - contextStringBuilder.append(" ]TSNI/[ "); // Reversed [/INST] to close instructions for when first message after system prompt is not from user - } - - promptStringBuilder.append(contextStringBuilder.reverse()); - return promptStringBuilder.toString().strip(); - } -}