Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed internal representation of messages and added LLM required ordering of messages. #26

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
Optional<String> prompt = promptCreator.createPrompt(threadState);
if (prompt.isEmpty()) {
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
Instant.now(), "I'm sorry but that request was too long for me.", threadState.tail());
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
}

body.put("inputs", prompt.get());
Expand All @@ -72,6 +72,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
String llmResponse = allGeneratedText.strip().replace(prompt.get().strip(), "");
Instant timestamp = Instant.now();

return threadState.newMessageFromBot(timestamp, llmResponse);
return threadState.newMessageFromBot(timestamp, llmResponse, threadState.tail());
}
}
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
Optional<ArrayNode> prunedMessages = pruneMessages(messages, null);
if (prunedMessages.isEmpty()) {
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
Instant.now(), "I'm sorry but that request was too long for me.",fromUser);
}
body.set("messages", prunedMessages.get());

Expand All @@ -182,6 +182,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
JsonNode choice = responseBody.get("choices").get(0);
String messageContent = choice.get("message").get("content").textValue();
return threadState.newMessageFromBot(timestamp, messageContent);
return threadState.newMessageFromBot(timestamp, messageContent,fromUser);
}
}
11 changes: 9 additions & 2 deletions src/main/java/com/meta/cp4m/message/FBMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,12 @@ public record FBMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role)
implements Message {}
Role role,
Message parentMessage)
implements Message {
private static final MessageFactory<FBMessage> MESSAGE_FACTORY = MessageFactory.instance(FBMessage.class);
@Override
public Message addParentMessage(Message parentMessage) {
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
return MESSAGE_FACTORY.newMessage(timestamp(),message(),senderId(),recipientId(),instanceId(), Role.USER, parentMessage);
}
}
11 changes: 3 additions & 8 deletions src/main/java/com/meta/cp4m/message/FBMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public class FBMessageHandler implements MessageHandler<FBMessage> {
private static final Logger LOGGER = LoggerFactory.getLogger(FBMessageHandler.class);
private static final TextChunker CHUNKER = TextChunker.standard(2000);

private static final MessageFactory<FBMessage> MESSAGE_FACTORY = MessageFactory.instance(FBMessage.class);

private final String verifyToken;
private final String appSecret;

Expand Down Expand Up @@ -125,14 +127,7 @@ private List<FBMessage> postHandler(Context ctx, JsonNode body) {

@Nullable JsonNode textObject = messageObject.get("text");
if (textObject != null && textObject.isTextual()) {
FBMessage m =
new FBMessage(
timestamp,
messageId,
senderId,
recipientId,
textObject.textValue(),
Message.Role.USER);
FBMessage m = MESSAGE_FACTORY.newMessage(timestamp, textObject.textValue(), senderId, recipientId,messageId, Message.Role.USER,null);
output.add(m);
} else {
LOGGER
Expand Down
24 changes: 19 additions & 5 deletions src/main/java/com/meta/cp4m/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
import java.time.Instant;

public interface Message {

static Identifier threadId(Identifier id1, Identifier id2) {
public static Identifier threadId(Identifier id1, Identifier id2) {
if (id1.compareTo(id2) <= 0) {
return Identifier.from(id1.toString() + '|' + id2);
}
return Identifier.from(id2.toString() + '|' + id1);
}
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

static void parentMessage(Message parentMessage){
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

}

NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
public Message addParentMessage(Message parentMessage);

Instant timestamp();

Identifier instanceId();
Expand All @@ -32,13 +37,22 @@ static Identifier threadId(Identifier id1, Identifier id2) {

Role role();
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

Message parentMessage();
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

default Identifier threadId() {
return threadId(senderId(), recipientId());
}

enum Role {
ASSISTANT,
USER,
SYSTEM
ASSISTANT(0),
USER(1),
SYSTEM(2);

public final Integer priority;
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

private Role(Integer priority){
this.priority = priority;
}

}
}
7 changes: 4 additions & 3 deletions src/main/java/com/meta/cp4m/message/MessageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public interface MessageFactory<T extends Message> {
Map<Class<? extends Message>, MessageFactory<? extends Message>> FACTORY_MAP =
Stream.<FactoryContainer<?>>of(
new FactoryContainer<>(
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r)),
FBMessage.class, (t, m, si, ri, ii, r, pm) -> new FBMessage(t, ii, si, ri, m, r, pm)),
new FactoryContainer<>(
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r)))
WAMessage.class, (t, m, si, ri, ii, r, pm) -> new WAMessage(t, ii, si, ri, m, r, pm)))
.collect(
Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory));

Expand All @@ -46,7 +46,8 @@ T newMessage(
Identifier senderId,
Identifier recipientId,
Identifier instanceId,
Role role);
Role role,
Message parentMessage);
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

/** this exists to provide compiler guarantees for type matching in the FACTORY_MAP */
record FactoryContainer<T extends Message>(Class<T> clazz, MessageFactory<T> factory) {}
Expand Down
20 changes: 9 additions & 11 deletions src/main/java/com/meta/cp4m/message/ThreadState.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,17 @@ private ThreadState(T message) {
private ThreadState(ThreadState<T> old, T newMessage) {
Objects.requireNonNull(newMessage);
Preconditions.checkArgument(
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
messageFactory = old.messageFactory;
Preconditions.checkArgument(
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
this.messages =
Stream.concat(messages.stream(), Stream.of(newMessage))
.sorted(Comparator.comparing(Message::timestamp))
T mWithParentMessage = newMessage.role() == Role.USER ? (T) newMessage.addParentMessage(old.tail()): newMessage;
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
this.messages =
Stream.concat(messages.stream(), Stream.of(mWithParentMessage))
.sorted((m1,m2) -> m1.parentMessage() == m2.parentMessage() ? (m1.role().priority.compareTo(m2.role().priority)) : (m1.timestamp().compareTo(m2.timestamp())))
.collect(Collectors.toUnmodifiableList());
} else {
this.messages = ImmutableList.<T>builder().addAll(messages).add(newMessage).build();
}

Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
Expand Down Expand Up @@ -77,13 +74,13 @@ public Identifier botId() {
};
}

public T newMessageFromBot(Instant timestamp, String message) {
public T newMessageFromBot(Instant timestamp, String message, T parentMessage) {
return messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT, parentMessage);
}

public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER, this.tail());
}

public ThreadState<T> with(T message) {
Expand All @@ -97,4 +94,5 @@ public List<T> messages() {
public T tail() {
return messages.get(messages.size() - 1);
}

}
11 changes: 9 additions & 2 deletions src/main/java/com/meta/cp4m/message/WAMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,12 @@ public record WAMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role)
implements Message {}
Role role,
Message parentMessage)
implements Message {
private static final MessageFactory<WAMessage> MESSAGE_FACTORY = MessageFactory.instance(WAMessage.class);
@Override
public Message addParentMessage(Message parentMessage) {
return MESSAGE_FACTORY.newMessage(timestamp(),message(),senderId(),recipientId(),instanceId(), Role.USER, parentMessage);
}
}
11 changes: 3 additions & 8 deletions src/main/java/com/meta/cp4m/message/WAMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class WAMessageHandler implements MessageHandler<WAMessage> {
private static final String API_VERSION = "v17.0";
private static final JsonMapper MAPPER = Utils.JSON_MAPPER;
private static final Logger LOGGER = LoggerFactory.getLogger(WAMessageHandler.class);
private static final MessageFactory<WAMessage> MESSAGE_FACTORY = MessageFactory.instance(WAMessage.class);

/**
* <a
Expand Down Expand Up @@ -95,15 +96,9 @@ private List<WAMessage> post(Context ctx, WebhookPayload payload) {
continue;
}
TextWebhookMessage textMessage = (TextWebhookMessage) message;
waMessages.add(
new WAMessage(
message.timestamp(),
message.id(),
message.from(),
phoneNumberId,
textMessage.text().body(),
Message.Role.USER));
WAMessage waMessage = MESSAGE_FACTORY.newMessage(message.timestamp(), textMessage.text().body(), message.from(), phoneNumberId,message.id(), Message.Role.USER,null);
readExecutor.execute(() -> markRead(phoneNumberId, textMessage.id().toString()));
waMessages.add(waMessage);
}
});
return waMessages;
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/com/meta/cp4m/store/ChatStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package com.meta.cp4m.store;

import com.meta.cp4m.Identifier;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.util.List;
Expand All @@ -28,4 +29,6 @@ public interface ChatStore<T extends Message> {
long size();

List<ThreadState<T>> list();

ThreadState<T> get(Identifier threadId);
}
7 changes: 6 additions & 1 deletion src/main/java/com/meta/cp4m/store/MemoryStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class MemoryStore<T extends Message> implements ChatStore<T> {
}

@Override
public ThreadState<T> add(T message) {
public ThreadState<T> add(T message) {
return this.store
.asMap()
.compute(
Expand All @@ -53,4 +53,9 @@ public long size() {
public List<ThreadState<T>> list() {
return store.asMap().values().stream().toList();
}

@Override
public ThreadState<T> get(Identifier threadId){
return this.store.asMap().get(threadId);
}
}
2 changes: 1 addition & 1 deletion src/test/java/com/meta/cp4m/llm/DummyLLMPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ public String dummyResponse() {
@Override
public T handle(ThreadState<T> threadState) {
receivedThreadStates.add(threadState);
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse);
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse, threadState.tail());
}
}
11 changes: 6 additions & 5 deletions src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public class HuggingFaceLlamaPluginTest {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER,
null));

static {
SAMPLE_RESPONSE.addObject().put("generated_text", TEST_MESSAGE);
Expand Down Expand Up @@ -145,7 +146,7 @@ void createPayloadWithSystemMessage() {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder =
new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
Optional<String> createdPayload = promptBuilder.createPrompt(stack);
Expand Down Expand Up @@ -173,7 +174,7 @@ void contextTooBig() throws IOException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
FBMessage response = plugin.handle(thread);
assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me.");
}
Expand All @@ -197,7 +198,7 @@ void truncatesContext() throws IOException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER,null));
thread =
thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2)));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder =
Expand Down Expand Up @@ -256,7 +257,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/com/meta/cp4m/llm/OpenAIPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class OpenAIPluginTest {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));

static {
((ObjectNode) SAMPLE_RESPONSE)
Expand Down Expand Up @@ -198,7 +198,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ void chunkingHappens() throws IOException {
Stream.generate(() -> "0123456789.").limit(300).collect(Collectors.joining(" "));
FBMessage bigMessage =
new FBMessage(
Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER);
Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER,null);
messageHandler.respond(bigMessage);
assertThat(requests.size()).isEqualTo(300);
assertThat(requests).allSatisfy(m -> assertThat(m.body()).contains("0123456789"));
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/com/meta/cp4m/message/MessageTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ void threadId() {
Identifier id0 = Identifier.from("0");
Identifier id1 = Identifier.from("1");
Identifier id2 = Identifier.from("2");
Message message = new FBMessage(timestamp, id0, id1, id2, "", Message.Role.ASSISTANT);
Message response = new FBMessage(timestamp, id0, id2, id1, "", Message.Role.ASSISTANT);
Message message = new FBMessage(timestamp, id0, id1, id2, "", Message.Role.ASSISTANT, null);
Message response = new FBMessage(timestamp, id0, id2, id1, "", Message.Role.ASSISTANT, null);
assertThat(message.threadId()).isEqualTo(response.threadId());

message =
Expand All @@ -33,15 +33,15 @@ void threadId() {
Identifier.from("12"),
Identifier.from("34"),
"",
Message.Role.ASSISTANT);
Message.Role.ASSISTANT, null);
response =
new FBMessage(
timestamp,
id0,
Identifier.from("1"),
Identifier.from("234"),
"",
Message.Role.ASSISTANT);
Message.Role.ASSISTANT, null);
assertThat(message.threadId()).isNotEqualTo(response.threadId());
}
}
4 changes: 2 additions & 2 deletions src/test/java/com/meta/cp4m/message/ServiceTestHarness.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ public URI webserverURI() {

public URI serviceURI() {
try {
return URIBuilder.localhost()
return URIBuilder.loopbackAddress()
.appendPath(SERVICE_PATH)
.setScheme("http")
.setPort(servicePort())
.build();
} catch (URISyntaxException | UnknownHostException e) {
} catch (URISyntaxException e) {
// this should be impossible
throw new RuntimeException(e);
}
Expand Down
Loading