Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed internal representation of messages and added LLM required ordering of messages. #26

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/main/java/com/meta/cp4m/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public Service(
}

void handle(Context ctx) {
List<T> messages = handler.processRequest(ctx);
List<T> messages = handler.processRequest(ctx, store);
// TODO: once we have a non-volatile store, on startup send stored but not replied to messages
for (T m : messages) {
ThreadState<T> thread = store.add(m);
Expand Down Expand Up @@ -69,6 +69,7 @@ private void execute(ThreadState<T> thread) {
LOGGER.error("failed to communicate with LLM", e);
return;
}
llmResponse = thread.tail();
store.add(llmResponse);
try {
handler.respond(llmResponse);
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
Optional<String> prompt = promptCreator.createPrompt(threadState);
if (prompt.isEmpty()) {
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
Instant.now(), "I'm sorry but that request was too long for me.", threadState.tail());
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
}

body.put("inputs", prompt.get());
Expand All @@ -72,6 +72,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
String llmResponse = allGeneratedText.strip().replace(prompt.get().strip(), "");
Instant timestamp = Instant.now();

return threadState.newMessageFromBot(timestamp, llmResponse);
return threadState.newMessageFromBot(timestamp, llmResponse, threadState.tail());
}
}
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
Optional<ArrayNode> prunedMessages = pruneMessages(messages, null);
if (prunedMessages.isEmpty()) {
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.");
Instant.now(), "I'm sorry but that request was too long for me.",fromUser);
}
body.set("messages", prunedMessages.get());

Expand All @@ -182,6 +182,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
JsonNode choice = responseBody.get("choices").get(0);
String messageContent = choice.get("message").get("content").textValue();
return threadState.newMessageFromBot(timestamp, messageContent);
return threadState.newMessageFromBot(timestamp, messageContent,fromUser);
}
}
3 changes: 2 additions & 1 deletion src/main/java/com/meta/cp4m/message/FBMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ public record FBMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role)
Role role,
Message parentMessage)
implements Message {}
20 changes: 9 additions & 11 deletions src/main/java/com/meta/cp4m/message/FBMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.meta.cp4m.Identifier;
import com.meta.cp4m.store.ChatStore;
import io.javalin.http.BadRequestResponse;
import io.javalin.http.Context;
import io.javalin.http.HandlerType;
Expand Down Expand Up @@ -40,6 +41,8 @@ public class FBMessageHandler implements MessageHandler<FBMessage> {
private static final Logger LOGGER = LoggerFactory.getLogger(FBMessageHandler.class);
private static final TextChunker CHUNKER = TextChunker.standard(2000);

private static final MessageFactory<FBMessage> MESSAGE_FACTORY = MessageFactory.instance(FBMessage.class);

private final String verifyToken;
private final String appSecret;

Expand Down Expand Up @@ -92,14 +95,14 @@ public FBMessageHandler(String verifyToken, String pageAccessToken, String appSe
}

@Override
public List<FBMessage> processRequest(Context ctx) {
public List<FBMessage> processRequest(Context ctx, ChatStore<FBMessage> store) {
try {
switch (ctx.handlerType()) {
case GET -> {
return getHandler(ctx);
}
case POST -> {
return postHandler(ctx);
return postHandler(ctx,store);
}
}
} catch (JsonProcessingException | NullPointerException e) {
Expand Down Expand Up @@ -133,7 +136,7 @@ String hmac(String body) {
return MetaHandlerUtils.hmac(body, appSecret);
}

private List<FBMessage> postHandler(Context ctx) throws JsonProcessingException {
private List<FBMessage> postHandler(Context ctx, ChatStore<FBMessage> store) throws JsonProcessingException {
MetaHandlerUtils.postHeaderValidator(ctx, appSecret);

String bodyString = ctx.body();
Expand Down Expand Up @@ -178,14 +181,9 @@ private List<FBMessage> postHandler(Context ctx) throws JsonProcessingException

@Nullable JsonNode textObject = messageObject.get("text");
if (textObject != null && textObject.isTextual()) {
FBMessage m =
new FBMessage(
timestamp,
messageId,
senderId,
recipientId,
textObject.textValue(),
Message.Role.USER);
ThreadState<FBMessage> thread = store.get(Message.threadId(senderId,recipientId));
FBMessage parentMessage = thread == null ? null : thread.tail();
FBMessage m = MESSAGE_FACTORY.newMessage(timestamp, textObject.textValue(), senderId, recipientId,messageId, Message.Role.USER,parentMessage);
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
output.add(m);
} else {
LOGGER
Expand Down
8 changes: 7 additions & 1 deletion src/main/java/com/meta/cp4m/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@

public interface Message {

static Identifier threadId(Identifier id1, Identifier id2) {
public static Identifier threadId(Identifier id1, Identifier id2) {
if (id1.compareTo(id2) <= 0) {
return Identifier.from(id1.toString() + '|' + id2);
}
return Identifier.from(id2.toString() + '|' + id1);
}
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

static void parentMessage(Message parentMessage){
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

}

NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
Instant timestamp();

Identifier instanceId();
Expand All @@ -32,6 +36,8 @@ static Identifier threadId(Identifier id1, Identifier id2) {

Role role();
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

Message parentMessage();
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

default Identifier threadId() {
return threadId(senderId(), recipientId());
}
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/com/meta/cp4m/message/MessageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public interface MessageFactory<T extends Message> {
Map<Class<? extends Message>, MessageFactory<? extends Message>> FACTORY_MAP =
Stream.<FactoryContainer<?>>of(
new FactoryContainer<>(
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r)),
FBMessage.class, (t, m, si, ri, ii, r, pm) -> new FBMessage(t, ii, si, ri, m, r, pm)),
new FactoryContainer<>(
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r)))
WAMessage.class, (t, m, si, ri, ii, r, pm) -> new WAMessage(t, ii, si, ri, m, r, pm)))
.collect(
Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory));

Expand All @@ -46,7 +46,8 @@ T newMessage(
Identifier senderId,
Identifier recipientId,
Identifier instanceId,
Role role);
Role role,
Message parentMessage);
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

/** this exists to provide compiler guarantees for type matching in the FACTORY_MAP */
record FactoryContainer<T extends Message>(Class<T> clazz, MessageFactory<T> factory) {}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/com/meta/cp4m/message/MessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package com.meta.cp4m.message;

import com.meta.cp4m.store.ChatStore;
import io.javalin.http.Context;
import io.javalin.http.HandlerType;
import java.io.IOException;
Expand All @@ -22,7 +23,7 @@ public interface MessageHandler<T extends Message> {
* @param ctx the context corresponding to an incoming request
* @return return a {@link Message} object if appropriate
*/
List<T> processRequest(Context ctx);
List<T> processRequest(Context ctx, ChatStore<T> store);
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

/**
* The method needed to respond to a message from a user
Expand Down
24 changes: 15 additions & 9 deletions src/main/java/com/meta/cp4m/message/ThreadState.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,19 @@ private ThreadState(ThreadState<T> old, T newMessage) {
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
this.messages =
// if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
// this.messages =
// Stream.concat(messages.stream(), Stream.of(newMessage))
// .sorted(Comparator.comparing(Message::timestamp))
// .collect(Collectors.toUnmodifiableList());
// } else {
// this.messages = ImmutableList.<T>builder().addAll(messages).add(newMessage).build();
// }

this.messages =
Stream.concat(messages.stream(), Stream.of(newMessage))
.sorted(Comparator.comparing(Message::timestamp))
.sorted((m1,m2) -> m1.parentMessage() == m2.parentMessage() ? (m1.role().compareTo(m2.role())) : (m1.timestamp().compareTo(m2.timestamp())))
.collect(Collectors.toUnmodifiableList());
} else {
this.messages = ImmutableList.<T>builder().addAll(messages).add(newMessage).build();
}

Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
Expand Down Expand Up @@ -77,13 +82,13 @@ public Identifier botId() {
};
}

public T newMessageFromBot(Instant timestamp, String message) {
public T newMessageFromBot(Instant timestamp, String message, T parentMessage) {
return messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT, parentMessage);
}

public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER, this.tail());
}

public ThreadState<T> with(T message) {
Expand All @@ -97,4 +102,5 @@ public List<T> messages() {
public T tail() {
return messages.get(messages.size() - 1);
}

}
3 changes: 2 additions & 1 deletion src/main/java/com/meta/cp4m/message/WAMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ public record WAMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role)
Role role,
Message parentMessage)
implements Message {}
19 changes: 8 additions & 11 deletions src/main/java/com/meta/cp4m/message/WAMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.meta.cp4m.message.webhook.whatsapp.Utils;
import com.meta.cp4m.message.webhook.whatsapp.WebhookMessage;
import com.meta.cp4m.message.webhook.whatsapp.WebhookPayload;
import com.meta.cp4m.store.ChatStore;
import io.javalin.http.Context;
import io.javalin.http.HandlerType;
import java.io.IOException;
Expand All @@ -40,6 +41,7 @@ public class WAMessageHandler implements MessageHandler<WAMessage> {
private static final String API_VERSION = "v17.0";
private static final JsonMapper MAPPER = Utils.JSON_MAPPER;
private static final Logger LOGGER = LoggerFactory.getLogger(FBMessageHandler.class);
private static final MessageFactory<WAMessage> MESSAGE_FACTORY = MessageFactory.instance(WAMessage.class);

/**
* <a
Expand Down Expand Up @@ -85,7 +87,7 @@ public WAMessageHandler(WAMessengerConfig config) {
}

@Override
public List<WAMessage> processRequest(Context ctx) {
public List<WAMessage> processRequest(Context ctx, ChatStore<WAMessage> store) {

try {
switch (ctx.handlerType()) {
Expand All @@ -95,7 +97,7 @@ public List<WAMessage> processRequest(Context ctx) {
return Collections.emptyList();
}
case POST -> {
return postHandler(ctx);
return postHandler(ctx,store);
}
}
} catch (RuntimeException e) {
Expand All @@ -108,7 +110,7 @@ public List<WAMessage> processRequest(Context ctx) {
throw new UnsupportedOperationException("Only accepting get and post methods");
}

List<WAMessage> postHandler(Context ctx) {
List<WAMessage> postHandler(Context ctx, ChatStore<WAMessage> store) {
MetaHandlerUtils.postHeaderValidator(ctx, appSecret);
String bodyString = ctx.body();
WebhookPayload payload;
Expand Down Expand Up @@ -142,14 +144,9 @@ List<WAMessage> postHandler(Context ctx) {
continue;
}
TextWebhookMessage textMessage = (TextWebhookMessage) message;
waMessages.add(
new WAMessage(
message.timestamp(),
message.id(),
message.from(),
phoneNumberId,
textMessage.text().body(),
Message.Role.USER));
ThreadState<WAMessage> thread = store.get(Message.threadId(message.from(),phoneNumberId));
WAMessage parentMessage = thread == null ? null : thread.tail();
WAMessage m = MESSAGE_FACTORY.newMessage(message.timestamp(), textMessage.text().body(), message.from(), phoneNumberId,message.id(), Message.Role.USER,parentMessage);
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
readExecutor.execute(() -> markRead(phoneNumberId, textMessage.id().toString()));
}
});
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/com/meta/cp4m/store/ChatStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package com.meta.cp4m.store;

import com.meta.cp4m.Identifier;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.util.List;
Expand All @@ -28,4 +29,6 @@ public interface ChatStore<T extends Message> {
long size();

List<ThreadState<T>> list();

ThreadState<T> get(Identifier threadId);
}
7 changes: 6 additions & 1 deletion src/main/java/com/meta/cp4m/store/MemoryStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class MemoryStore<T extends Message> implements ChatStore<T> {
}

@Override
public ThreadState<T> add(T message) {
public ThreadState<T> add(T message) {
return this.store
.asMap()
.compute(
Expand All @@ -53,4 +53,9 @@ public long size() {
public List<ThreadState<T>> list() {
return store.asMap().values().stream().toList();
}

@Override
public ThreadState<T> get(Identifier threadId){
return this.store.asMap().get(threadId);
}
}
2 changes: 1 addition & 1 deletion src/test/java/com/meta/cp4m/llm/DummyLLMPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ public String dummyResponse() {
@Override
public T handle(ThreadState<T> threadState) {
receivedThreadStates.add(threadState);
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse);
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse, threadState.tail());
}
}
11 changes: 6 additions & 5 deletions src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ public class HuggingFaceLlamaPluginTest {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER,
null));

static {
SAMPLE_RESPONSE.addObject().put("generated_text", TEST_MESSAGE);
Expand Down Expand Up @@ -145,7 +146,7 @@ void createPayloadWithSystemMessage() {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder =
new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
Optional<String> createdPayload = promptBuilder.createPrompt(stack);
Expand Down Expand Up @@ -173,7 +174,7 @@ void contextTooBig() throws IOException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
FBMessage response = plugin.handle(thread);
assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me.");
}
Expand All @@ -197,7 +198,7 @@ void truncatesContext() throws IOException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER,null));
thread =
thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2)));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder =
Expand Down Expand Up @@ -256,7 +257,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER));
Role.USER, null));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
Expand Down
Loading
Loading