Skip to content

Commit

Permalink
Merge branch 'main' into container_registry
Browse files Browse the repository at this point in the history
  • Loading branch information
hunterjackson authored Nov 6, 2023
2 parents 0fcc36d + 65d1246 commit e09c332
Show file tree
Hide file tree
Showing 10 changed files with 93,792 additions and 291 deletions.
2 changes: 1 addition & 1 deletion .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@
<artifactId>jtokkit</artifactId>
<version>0.6.1</version>
</dependency>
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.23.0</version>
</dependency>
</dependencies>
<build>
<finalName>${custom.jarName}</finalName>
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ public Map<Long, Double> logitBias() {
return logitBias;
}

public Optional<String> systemMessage() {
return Optional.ofNullable(systemMessage);
public String systemMessage() {
if (systemMessage == null) {
return "You're a helpful assistant.";
}
return systemMessage;
}

public long maxInputTokens() {
Expand Down
72 changes: 17 additions & 55 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;

import java.io.IOException;
import java.net.URI;
import java.time.Instant;
import java.util.Optional;

import org.apache.hc.client5.http.fluent.Request;
import org.apache.hc.client5.http.fluent.Response;
import org.apache.hc.core5.http.ContentType;
Expand All @@ -25,17 +28,18 @@ public class HuggingFaceLlamaPlugin<T extends Message> implements LLMPlugin<T> {

private static final ObjectMapper MAPPER = new ObjectMapper();
private final HuggingFaceConfig config;
private final HuggingFaceLlamaPrompt<T> promptCreator;

private URI endpoint;

public HuggingFaceLlamaPlugin(HuggingFaceConfig config) {
this.config = config;
this.endpoint = this.config.endpoint();
this.endpoint = this.config.endpoint();
promptCreator = new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
}

@Override
public T handle(ThreadState<T> messageStack) throws IOException {
T fromUser = messageStack.tail();

@Override
public T handle(ThreadState<T> threadState) throws IOException {
ObjectNode body = MAPPER.createObjectNode();
ObjectNode params = MAPPER.createObjectNode();

Expand All @@ -45,9 +49,12 @@ public T handle(ThreadState<T> messageStack) throws IOException {

body.set("parameters", params);

String prompt = createPrompt(messageStack);
Optional<String> prompt = promptCreator.createPrompt(threadState);
if (prompt.isEmpty()) {
return threadState.newMessageFromBot(Instant.now(), "I'm sorry but that request was too long for me.");
}

body.put("inputs", prompt);
body.put("inputs", prompt.get());

String bodyString;
try {
Expand All @@ -63,54 +70,9 @@ public T handle(ThreadState<T> messageStack) throws IOException {

JsonNode responseBody = MAPPER.readTree(response.returnContent().asBytes());
String allGeneratedText = responseBody.get(0).get("generated_text").textValue();
String llmResponse = allGeneratedText.strip().replace(prompt.strip(), "");
String llmResponse = allGeneratedText.strip().replace(prompt.get().strip(), "");
Instant timestamp = Instant.now();

return messageStack.newMessageFromBot(timestamp, llmResponse);
}

public String createPrompt(ThreadState<T> MessageStack) {
StringBuilder promptBuilder = new StringBuilder();
if(config.systemMessage().isPresent()){
promptBuilder.append("<s>[INST] <<SYS>>\n").append(config.systemMessage().get()).append("\n<</SYS>>\n\n");
} else if(MessageStack.messages().get(0).role() == Message.Role.SYSTEM){
promptBuilder.append("<s>[INST] <<SYS>>\n").append(MessageStack.messages().get(0).message()).append("\n<</SYS>>\n\n");
}
else {
promptBuilder.append("<s>[INST] ");
}

// The first user input is _not_ stripped
boolean doStrip = false;
Message.Role lastMessageSender = Message.Role.SYSTEM;

for (T message : MessageStack.messages()) {
String text = doStrip ? message.message().strip() : message.message();
Message.Role user = message.role();
if (user == Message.Role.SYSTEM){
continue;
}
boolean isUser = user == Message.Role.USER;
if(isUser){
doStrip = true;
}

if(isUser && lastMessageSender == Message.Role.ASSISTANT){
promptBuilder.append(" </s><s>[INST] ");
}
if(user == Message.Role.ASSISTANT && lastMessageSender == Message.Role.USER){
promptBuilder.append(" [/INST] ");
}
promptBuilder.append(text);

lastMessageSender = user;
}
if(lastMessageSender == Message.Role.ASSISTANT){
promptBuilder.append(" </s>");
} else if (lastMessageSender == Message.Role.USER){
promptBuilder.append(" [/INST]");
}

return promptBuilder.toString();
return threadState.newMessageFromBot(timestamp, llmResponse);
}
}
}
109 changes: 109 additions & 0 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPrompt.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.llm;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;

public class HuggingFaceLlamaPrompt<T extends Message> {

private final String systemMessage;
private final long maxInputTokens;
private final HuggingFaceTokenizer tokenizer;

public HuggingFaceLlamaPrompt(String systemMessage, long maxInputTokens) {

this.systemMessage = systemMessage;
this.maxInputTokens = maxInputTokens;
URL llamaTokenizerUrl =
Objects.requireNonNull(
HuggingFaceLlamaPrompt.class.getClassLoader().getResource("llamaTokenizer.json"));
URI llamaTokenizer;
try {
llamaTokenizer = llamaTokenizerUrl.toURI();
tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(llamaTokenizer));

} catch (URISyntaxException | IOException e) {
// this should be impossible
throw new RuntimeException(e);
}
}

public Optional<String> createPrompt(ThreadState<T> threadState) {

PromptBuilder builder = new PromptBuilder();

int totalTokens = tokenCount(this.systemMessage) + 5; // Account for closing tokens
builder.addSystem(this.systemMessage);

for (int i = threadState.messages().size() - 1; i >= 0; i--) {
Message m = threadState.messages().get(i);
totalTokens += tokenCount(m.message());
if (totalTokens > maxInputTokens) {
if (i == threadState.messages().size() - 1){
return Optional.empty();
}
break;
}
switch (m.role()) {
case USER -> builder.addUser(m.message());
case ASSISTANT -> builder.addAssistant(m.message());
}
}

return Optional.of(builder.build());
}

private int tokenCount(String message) {
Encoding encoding = tokenizer.encode(message);
return encoding.getTokens().length;
}

private static class PromptBuilder {

StringBuilder promptStringBuilder = new StringBuilder();
StringBuilder messagesStringBuilder = new StringBuilder();

void addSystem(String message) {
promptStringBuilder
.append("<s>[INST] <<SYS>>\n")
.append(message)
.append("\n<</SYS>>\n\n");
}

void addAssistant(String message) {
StringBuilder tempBuilder = new StringBuilder();
tempBuilder
.append(message)
.append(" </s><s>[INST] ");
messagesStringBuilder.append(tempBuilder.reverse());
}

void addUser(String message) {
StringBuilder tempBuilder = new StringBuilder();
tempBuilder
.append(message)
.append(" [/INST] ");
messagesStringBuilder.append(tempBuilder.reverse());
}

String build() {
return promptStringBuilder.append(messagesStringBuilder.reverse()).toString().strip();
}
}
}
7 changes: 5 additions & 2 deletions src/main/java/com/meta/cp4m/llm/OpenAIConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ public Map<Long, Double> logitBias() {
return logitBias;
}

public Optional<String> systemMessage() {
return Optional.ofNullable(systemMessage);
public String systemMessage() {
if (systemMessage == null) {
return "You're a helpful assistant.";
}
return systemMessage;
}

public long maxInputTokens() {
Expand Down
12 changes: 4 additions & 8 deletions src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,10 @@ public T handle(ThreadState<T> threadState) throws IOException {
}

ArrayNode messages = MAPPER.createArrayNode();
config
.systemMessage()
.ifPresent(
m ->
messages
.addObject()
.put("role", Role.SYSTEM.toString().toLowerCase())
.put("content", m));
messages
.addObject()
.put("role", Role.SYSTEM.toString().toLowerCase())
.put("content", config.systemMessage());
for (T message : threadState.messages()) {
messages
.addObject()
Expand Down
Loading

0 comments on commit e09c332

Please sign in to comment.