Skip to content

Commit

Permalink
Addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
colinmccloskey committed Nov 2, 2023
1 parent bea6e96 commit 6342b3d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class HuggingFaceLlamaPlugin<T extends Message> implements LLMPlugin<T> {
public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
this.config = config;
this.endpoint = this.config.endpoint();
promptCreator = new HuggingFaceLlamaPrompt<>(config);
promptCreator = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
}

@Override
Expand Down
25 changes: 12 additions & 13 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import com.meta.cp4m.Identifier;
import com.meta.cp4m.message.FBMessage;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.MessageFactory;
import com.meta.cp4m.message.ThreadState;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.time.Instant;
import java.util.*;

public class HuggingFaceLlamaPrompt<T extends Message> {
Expand All @@ -30,10 +26,10 @@ public class HuggingFaceLlamaPrompt<T extends Message> {
private final long maxInputTokens;
private final HuggingFaceTokenizer tokenizer;

public HuggingFaceLlamaPrompt(HuggingFaceConfig config) {
public HuggingFaceLlamaPrompt(String systemMessage, long maxInputTokens) {

this.systemMessage = config.systemMessage();
this.maxInputTokens = config.maxInputTokens();
this.systemMessage = systemMessage;
this.maxInputTokens = maxInputTokens;
URL llamaTokenizerUrl =
Objects.requireNonNull(
HuggingFaceLlamaPrompt.class.getClassLoader().getResource("llamaTokenizer.json"));
Expand All @@ -51,7 +47,7 @@ public HuggingFaceLlamaPrompt(HuggingFaceConfig config) {
public Optional<String> createPrompt(ThreadState<T> threadState) {

PromptBuilder builder = new PromptBuilder();

int totalTokens = tokenCount(this.systemMessage) + 5; // Account for closing tokens
builder.addSystem(this.systemMessage);

Expand Down Expand Up @@ -81,6 +77,7 @@ private int tokenCount(String message) {
private static class PromptBuilder {

StringBuilder promptStringBuilder = new StringBuilder();
StringBuilder messagesStringBuilder = new StringBuilder();

void addSystem(String message) {
promptStringBuilder
Expand All @@ -90,21 +87,23 @@ void addSystem(String message) {
}

void addAssistant(String message) {
promptStringBuilder
StringBuilder tempBuilder = new StringBuilder();
tempBuilder
.append(message)
.append(" </s><s>[INST] ");

messagesStringBuilder.append(tempBuilder.reverse());
}

void addUser(String message) {
promptStringBuilder
StringBuilder tempBuilder = new StringBuilder();
tempBuilder
.append(message)
.append(" [/INST] ");

messagesStringBuilder.append(tempBuilder.reverse());
}

String build() {
return promptStringBuilder.toString().strip();
return promptStringBuilder.append(messagesStringBuilder.reverse()).toString().strip();
}
}
}
10 changes: 5 additions & 5 deletions src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void createPayload() {
HuggingFaceConfig config =
HuggingFaceConfig.builder(apiKey).endpoint(endpoint.toString()).tokenLimit(100).build();
HuggingFaceLlamaPlugin<FBMessage> plugin = new HuggingFaceLlamaPlugin<>(config);
HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config);
HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
Optional<String> createdPayload = promptBuilder.createPrompt(STACK);
assertThat(createdPayload).isPresent();
assertThat(createdPayload.get()).isEqualTo(TEST_PAYLOAD);
Expand All @@ -142,7 +142,7 @@ void createPayloadWithSystemMessage() {
Identifier.random(),
Identifier.random(),
Role.USER));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config);
HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
Optional<String> createdPayload = promptBuilder.createPrompt(stack);
assertThat(createdPayload).isPresent();
assertThat(createdPayload.get()).isEqualTo(TEST_PAYLOAD_WITH_SYSTEM);
Expand Down Expand Up @@ -194,7 +194,7 @@ void truncatesContext() throws IOException {
Identifier.random(),
Role.USER));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2)));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config);
HuggingFaceLlamaPrompt<FBMessage> promptBuilder = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
Optional<String> createdPayload = promptBuilder.createPrompt(thread);
assertThat(createdPayload).isPresent();
assertThat(createdPayload.get()).isEqualTo(TEST_PAYLOAD);
Expand Down Expand Up @@ -226,7 +226,6 @@ void validConfigValues(HuggingFaceConfigTest.ConfigItem configItem)
assertThatCode(() -> STACK.with(message)).doesNotThrowAnyException();
@Nullable OutboundRequest or = HuggingFaceLlamaRequests.poll(500, TimeUnit.MILLISECONDS);
assertThat(or).isNotNull();
System.out.println(or);
assertThat(or.headerMap().get("Authorization"))
.isNotNull()
.isEqualTo("Bearer " + config.apiKey());
Expand All @@ -239,6 +238,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
.maxInputTokens(100)
.tokenLimit(200)
.endpoint(endpoint.toString())
.systemMessage("0")
.build();
HuggingFaceLlamaPlugin<FBMessage> plugin = new HuggingFaceLlamaPlugin<>(config);
ThreadState<FBMessage> stack =
Expand All @@ -250,7 +250,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.SYSTEM));
Role.ASSISTANT));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
Expand Down

0 comments on commit 6342b3d

Please sign in to comment.