From 6342b3d62bb7d0f4911229a6e4fe00a96df54bc6 Mon Sep 17 00:00:00 2001 From: colin <98445953+colinmccloskey@users.noreply.github.com> Date: Thu, 2 Nov 2023 18:37:20 -0400 Subject: [PATCH] Addressing comments --- .../meta/cp4m/llm/HuggingFaceLlamaPlugin.java | 2 +- .../meta/cp4m/llm/HuggingFaceLlamaPrompt.java | 25 +++++++++---------- .../cp4m/llm/HuggingFaceLlamaPluginTest.java | 10 ++++---- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java index 85affba..126d511 100644 --- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java +++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java @@ -35,7 +35,7 @@ public class HuggingFaceLlamaPlugin implements LLMPlugin { public HuggingFaceLlamaPlugin(HuggingFaceConfig config) { this.config = config; this.endpoint = this.config.endpoint(); - promptCreator = new HuggingFaceLlamaPrompt<>(config); + promptCreator = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens()); } @Override diff --git a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java index 0a46f7e..c244ec1 100644 --- a/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java +++ b/src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java @@ -10,10 +10,7 @@ import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import com.meta.cp4m.Identifier; -import com.meta.cp4m.message.FBMessage; import com.meta.cp4m.message.Message; -import com.meta.cp4m.message.MessageFactory; import com.meta.cp4m.message.ThreadState; import java.io.IOException; @@ -21,7 +18,6 @@ import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Paths; -import java.time.Instant; import java.util.*; public class HuggingFaceLlamaPrompt { @@ -30,10 +26,10 @@ public class HuggingFaceLlamaPrompt { private final long maxInputTokens; private final HuggingFaceTokenizer tokenizer; - public HuggingFaceLlamaPrompt(HuggingFaceConfig config) { + public HuggingFaceLlamaPrompt(String systemMessage, long maxInputTokens) { - this.systemMessage = config.systemMessage(); - this.maxInputTokens = config.maxInputTokens(); + this.systemMessage = systemMessage; + this.maxInputTokens = maxInputTokens; URL llamaTokenizerUrl = Objects.requireNonNull( HuggingFaceLlamaPrompt.class.getClassLoader().getResource("llamaTokenizer.json")); @@ -51,7 +47,7 @@ public HuggingFaceLlamaPrompt(HuggingFaceConfig config) { public Optional createPrompt(ThreadState threadState) { PromptBuilder builder = new PromptBuilder(); - + int totalTokens = tokenCount(this.systemMessage) + 5; // Account for closing tokens builder.addSystem(this.systemMessage); @@ -81,6 +77,7 @@ private int tokenCount(String message) { private static class PromptBuilder { StringBuilder promptStringBuilder = new StringBuilder(); + StringBuilder messagesStringBuilder = new StringBuilder(); void addSystem(String message) { promptStringBuilder @@ -90,21 +87,23 @@ void addSystem(String message) { } void addAssistant(String message) { - promptStringBuilder + StringBuilder tempBuilder = new StringBuilder(); + tempBuilder .append(message) .append(" [INST] "); - + messagesStringBuilder.append(tempBuilder.reverse()); } void addUser(String message) { - promptStringBuilder + StringBuilder tempBuilder = new StringBuilder(); + tempBuilder .append(message) .append(" [/INST] "); - + messagesStringBuilder.append(tempBuilder.reverse()); } String build() { - return promptStringBuilder.toString().strip(); + return promptStringBuilder.append(messagesStringBuilder.reverse()).toString().strip(); } } } \ No newline at end of file diff --git a/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java b/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java index 941009e..5be741f 100644 --- a/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java +++ b/src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java @@ -120,7 +120,7 @@ void createPayload() { HuggingFaceConfig config = HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build(); HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); - HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config); + HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens()); Optional createdPayload = promptBuilder.createPrompt(STACK); assertThat(createdPayload).isPresent(); assertThat(createdPayload.get()).isEqualTo(TEST_PAYLOAD); @@ -142,7 +142,7 @@ void createPayloadWithSystemMessage() { Identifier.random(), Identifier.random(), Role.USER)); - HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config); + HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens()); Optional createdPayload = promptBuilder.createPrompt(stack); assertThat(createdPayload).isPresent(); assertThat(createdPayload.get()).isEqualTo(TEST_PAYLOAD_WITH_SYSTEM); @@ -194,7 +194,7 @@ void truncatesContext() throws IOException { Identifier.random(), Role.USER)); thread = thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2))); - HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config); + HuggingFaceLlamaPrompt promptBuilder = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens()); Optional createdPayload = promptBuilder.createPrompt(thread); assertThat(createdPayload).isPresent(); assertThat(createdPayload.get()).isEqualTo(TEST_PAYLOAD); @@ -226,7 +226,6 @@ void validConfigValues(HuggingFaceConfigTest.ConfigItem configItem) assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException(); @Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS); assertThat(or).isNotNull(); - System.out.println(or); assertThat(or.headerMap().get("Authorization")) .isNotNull() .isEqualTo("Bearer " + config.apiKey()); @@ -239,6 +238,7 @@ void orderedCorrectly() throws IOException, InterruptedException { .maxInputTokens(100) .tokenLimit(200) .endpoint(endpoint.toString()) + .systemMessage("0") .build(); HuggingFaceLlamaPlugin plugin = new HuggingFaceLlamaPlugin<>(config); ThreadState stack = @@ -250,7 +250,7 @@ void orderedCorrectly() throws IOException, InterruptedException { Identifier.random(), Identifier.random(), Identifier.random(), - Role.SYSTEM)); + Role.ASSISTANT)); stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2))); stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3))); stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));