-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mccloskey/llama2 hf #17
Changes from 13 commits
6491ed2
3ade3a0
31dfcd5
172c587
2a1fe09
2c3fb97
ab41ff3
c802798
a70d5a3
d8441e1
422367d
e640b6a
068752d
1bdd4a4
18541c7
bea6e96
6342b3d
0154e86
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,9 +14,11 @@ | |
import com.fasterxml.jackson.databind.node.ObjectNode; | ||
import com.meta.cp4m.message.Message; | ||
import com.meta.cp4m.message.ThreadState; | ||
|
||
import java.io.IOException; | ||
import java.net.URI; | ||
import java.time.Instant; | ||
|
||
import org.apache.hc.client5.http.fluent.Request; | ||
import org.apache.hc.client5.http.fluent.Response; | ||
import org.apache.hc.core5.http.ContentType; | ||
|
@@ -25,17 +27,18 @@ public class HuggingFaceLlamaPlugin<T extends Message> implements LLMPlugin<T> { | |
|
||
private static final ObjectMapper MAPPER = new ObjectMapper(); | ||
private final HuggingFaceConfig config; | ||
private final HuggingFaceLlamaPrompt<T> promptCreator; | ||
|
||
private URI endpoint; | ||
|
||
public HuggingFaceLlamaPlugin(HuggingFaceConfig config) { | ||
this.config = config; | ||
this.endpoint = this.config.endpoint(); | ||
this.endpoint = this.config.endpoint(); | ||
promptCreator = new HuggingFaceLlamaPrompt<>(config); | ||
} | ||
|
||
@Override | ||
public T handle(ThreadState<T> messageStack) throws IOException { | ||
T fromUser = messageStack.tail(); | ||
|
||
@Override | ||
public T handle(ThreadState<T> threadState) throws IOException { | ||
ObjectNode body = MAPPER.createObjectNode(); | ||
ObjectNode params = MAPPER.createObjectNode(); | ||
|
||
|
@@ -45,15 +48,18 @@ public T handle(ThreadState<T> messageStack) throws IOException { | |
|
||
body.set("parameters", params); | ||
|
||
String prompt = createPrompt(messageStack); | ||
String prompt = promptCreator.createPrompt(threadState); | ||
if (prompt.equals("I'm sorry but that request was too long for me.")) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. just have |
||
return threadState.newMessageFromBot(Instant.now(), prompt); | ||
} | ||
|
||
body.put("inputs", prompt); | ||
|
||
String bodyString; | ||
try { | ||
bodyString = MAPPER.writeValueAsString(body); | ||
} catch (JsonProcessingException e) { | ||
throw new RuntimeException(e); // this should be impossible | ||
throw new RuntimeException(e); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just add a comment saying this should be impossible. |
||
} | ||
Response response = | ||
Request.post(endpoint) | ||
|
@@ -66,51 +72,6 @@ public T handle(ThreadState<T> messageStack) throws IOException { | |
String llmResponse = allGeneratedText.strip().replace(prompt.strip(), ""); | ||
colinmccloskey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Instant timestamp = Instant.now(); | ||
|
||
return messageStack.newMessageFromBot(timestamp, llmResponse); | ||
} | ||
|
||
public String createPrompt(ThreadState<T> MessageStack) { | ||
StringBuilder promptBuilder = new StringBuilder(); | ||
if(config.systemMessage().isPresent()){ | ||
promptBuilder.append("<s>[INST] <<SYS>>\n").append(config.systemMessage().get()).append("\n<</SYS>>\n\n"); | ||
} else if(MessageStack.messages().get(0).role() == Message.Role.SYSTEM){ | ||
promptBuilder.append("<s>[INST] <<SYS>>\n").append(MessageStack.messages().get(0).message()).append("\n<</SYS>>\n\n"); | ||
} | ||
else { | ||
promptBuilder.append("<s>[INST] "); | ||
} | ||
|
||
// The first user input is _not_ stripped | ||
boolean doStrip = false; | ||
Message.Role lastMessageSender = Message.Role.SYSTEM; | ||
|
||
for (T message : MessageStack.messages()) { | ||
String text = doStrip ? message.message().strip() : message.message(); | ||
Message.Role user = message.role(); | ||
if (user == Message.Role.SYSTEM){ | ||
continue; | ||
} | ||
boolean isUser = user == Message.Role.USER; | ||
if(isUser){ | ||
doStrip = true; | ||
} | ||
|
||
if(isUser && lastMessageSender == Message.Role.ASSISTANT){ | ||
promptBuilder.append(" </s><s>[INST] "); | ||
} | ||
if(user == Message.Role.ASSISTANT && lastMessageSender == Message.Role.USER){ | ||
promptBuilder.append(" [/INST] "); | ||
} | ||
promptBuilder.append(text); | ||
|
||
lastMessageSender = user; | ||
} | ||
if(lastMessageSender == Message.Role.ASSISTANT){ | ||
promptBuilder.append(" </s>"); | ||
} else if (lastMessageSender == Message.Role.USER){ | ||
promptBuilder.append(" [/INST]"); | ||
} | ||
|
||
return promptBuilder.toString(); | ||
return threadState.newMessageFromBot(timestamp, llmResponse); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
/* | ||
* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
package com.meta.cp4m.llm; | ||
|
||
import ai.djl.huggingface.tokenizers.Encoding; | ||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; | ||
import com.meta.cp4m.Identifier; | ||
import com.meta.cp4m.message.FBMessage; | ||
import com.meta.cp4m.message.Message; | ||
import com.meta.cp4m.message.MessageFactory; | ||
import com.meta.cp4m.message.ThreadState; | ||
|
||
import java.io.IOException; | ||
import java.net.URI; | ||
import java.net.URISyntaxException; | ||
import java.net.URL; | ||
import java.nio.file.Paths; | ||
import java.time.Instant; | ||
import java.util.*; | ||
|
||
public class HuggingFaceLlamaPrompt<T extends Message> { | ||
|
||
private final String systemMessage; | ||
private final long maxInputTokens; | ||
private final HuggingFaceTokenizer tokenizer; | ||
|
||
public HuggingFaceLlamaPrompt(HuggingFaceConfig config) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably best not to strongly couple unrelated objects. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ahh okay, I did this because it's how it was done in the OpenAI plugin. Should I just pass in all the variables used in the config, then? |
||
|
||
this.systemMessage = config.systemMessage(); | ||
this.maxInputTokens = config.maxInputTokens(); | ||
URL llamaTokenizerUrl = | ||
Objects.requireNonNull( | ||
HuggingFaceLlamaPrompt.class.getClassLoader().getResource("llamaTokenizer.json")); | ||
URI llamaTokenizer; | ||
try { | ||
llamaTokenizer = llamaTokenizerUrl.toURI(); | ||
tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(llamaTokenizer)); | ||
|
||
} catch (URISyntaxException | IOException e) { | ||
// this should be impossible | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
public String createPrompt(ThreadState<T> threadState) { | ||
|
||
PromptBuilder builder = new PromptBuilder(); | ||
|
||
int totalTokens = 5; // Account for closing tokens | ||
Message systemMessage = threadState.messages().get(0).role().equals(Message.Role.SYSTEM) ? threadState.messages().get(0) : MessageFactory.instance(FBMessage.class) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. T168681714 |
||
.newMessage( | ||
Instant.now(), | ||
this.systemMessage, | ||
Identifier.random(), | ||
Identifier.random(), | ||
Identifier.random(), | ||
Message.Role.SYSTEM); | ||
ArrayList<Message> output = new ArrayList<>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a way we can avoid creating an intermediary list object here? If not, let's at least initialize the capacity of the array list, because we know it. |
||
totalTokens += tokenCount(systemMessage.message()); | ||
for (int i = threadState.messages().size() - 1; i >= 0; i--) { | ||
Message m = threadState.messages().get(i); | ||
|
||
if (m.role().equals(Message.Role.SYSTEM)) { | ||
colinmccloskey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
continue; // system has already been counted | ||
} | ||
totalTokens += tokenCount(m.message()); | ||
if (totalTokens > maxInputTokens) { | ||
break; | ||
} | ||
output.add(0, m); | ||
} | ||
if (output.isEmpty()) { | ||
return "I'm sorry but that request was too long for me."; | ||
} | ||
output.add(0, systemMessage); | ||
|
||
for (Message message : output) { | ||
switch (message.role()) { | ||
case SYSTEM -> builder.addSystem(message); | ||
case USER -> builder.addUser(message); | ||
case ASSISTANT -> builder.addAssistant(message); | ||
} | ||
} | ||
|
||
return builder.build(); | ||
} | ||
|
||
private int tokenCount(String message) { | ||
Encoding encoding = tokenizer.encode(message); | ||
return encoding.getTokens().length - 1; | ||
colinmccloskey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
private static class PromptBuilder { | ||
colinmccloskey marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
int totalTokens = 5; | ||
StringBuilder promptStringBuilder = new StringBuilder(); | ||
|
||
void addSystem(Message message) { | ||
promptStringBuilder | ||
.append("<s>[INST] <<SYS>>\n") | ||
.append(message.message()) | ||
.append("\n<</SYS>>\n\n"); | ||
} | ||
|
||
void addAssistant(Message message) { | ||
promptStringBuilder | ||
.append(message.message()) | ||
.append(" </s><s>[INST] "); | ||
|
||
} | ||
|
||
void addUser(Message message) { | ||
promptStringBuilder | ||
.append(message.message()) | ||
.append(" [/INST] "); | ||
|
||
} | ||
|
||
String build() { | ||
return promptStringBuilder.toString().strip(); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍