Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed internal representation of messages and added LLM required ordering of messages. #26

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/main/java/com/meta/cp4m/message/FBMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package com.meta.cp4m.message;

import com.meta.cp4m.Identifier;
import org.checkerframework.checker.lock.qual.NewObject;

import java.time.Instant;

public record FBMessage(
Expand All @@ -17,5 +19,11 @@ public record FBMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role)
implements Message {}
Role role,
Message parentMessage)
implements Message {
@Override
public @NewObject Message withParentMessage(Message parentMessage) {
return new FBMessage(timestamp(),instanceId(),senderId(),recipientId(),message(), role(), parentMessage);
}
}
11 changes: 3 additions & 8 deletions src/main/java/com/meta/cp4m/message/FBMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public class FBMessageHandler implements MessageHandler<FBMessage> {
private static final Logger LOGGER = LoggerFactory.getLogger(FBMessageHandler.class);
private static final TextChunker CHUNKER = TextChunker.standard(2000);

private static final MessageFactory<FBMessage> MESSAGE_FACTORY = MessageFactory.instance(FBMessage.class);

private final String verifyToken;
private final String appSecret;

Expand Down Expand Up @@ -125,14 +127,7 @@ private List<FBMessage> postHandler(Context ctx, JsonNode body) {

@Nullable JsonNode textObject = messageObject.get("text");
if (textObject != null && textObject.isTextual()) {
FBMessage m =
new FBMessage(
timestamp,
messageId,
senderId,
recipientId,
textObject.textValue(),
Message.Role.USER);
FBMessage m = MESSAGE_FACTORY.newMessage(timestamp, textObject.textValue(), senderId, recipientId,messageId, Message.Role.USER);
output.add(m);
} else {
LOGGER
Expand Down
27 changes: 22 additions & 5 deletions src/main/java/com/meta/cp4m/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@
package com.meta.cp4m.message;

import com.meta.cp4m.Identifier;
import org.checkerframework.checker.lock.qual.NewObject;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.time.Instant;

public interface Message {

static Identifier threadId(Identifier id1, Identifier id2) {
private static Identifier threadId(Identifier id1, Identifier id2) {
if (id1.compareTo(id2) <= 0) {
return Identifier.from(id1.toString() + '|' + id2);
}
return Identifier.from(id2.toString() + '|' + id1);
}

public <T extends Message> @NewObject T withParentMessage(Message parentMessage);

NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
Instant timestamp();

Identifier instanceId();
Expand All @@ -32,13 +36,26 @@ static Identifier threadId(Identifier id1, Identifier id2) {

Role role();
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved

@Nullable Message parentMessage();

default Identifier threadId() {
return threadId(senderId(), recipientId());
}

enum Role {
ASSISTANT,
USER,
SYSTEM
ASSISTANT(0),
USER(1),
SYSTEM(2);

private final int priority;

private Role(Integer priority){
this.priority = priority;
}

public int getPriority(){
NanditaRao marked this conversation as resolved.
Show resolved Hide resolved
return this.priority;
}

}
}
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/message/MessageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public interface MessageFactory<T extends Message> {
Map<Class<? extends Message>, MessageFactory<? extends Message>> FACTORY_MAP =
Stream.<FactoryContainer<?>>of(
new FactoryContainer<>(
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r)),
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r,null)),
new FactoryContainer<>(
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r)))
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r,null)))
.collect(
Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory));

Expand Down
26 changes: 16 additions & 10 deletions src/main/java/com/meta/cp4m/message/ThreadState.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,30 @@ private ThreadState(T message) {
private ThreadState(ThreadState<T> old, T newMessage) {
Objects.requireNonNull(newMessage);
Preconditions.checkArgument(
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
messageFactory = old.messageFactory;
Preconditions.checkArgument(
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
if (newMessage.timestamp().isBefore(old.tail().timestamp())) {
this.messages =
Stream.concat(messages.stream(), Stream.of(newMessage))
.sorted(Comparator.comparing(Message::timestamp))
T mWithParentMessage = newMessage.role() == Role.USER ? newMessage.withParentMessage(old.tail()): newMessage;
this.messages =
Stream.concat(messages.stream(), Stream.of(mWithParentMessage))
.sorted((m1,m2) -> m1.parentMessage() == m2.parentMessage() ? compare(m1.role().getPriority(),m2.role().getPriority()) : (m1.timestamp().compareTo(m2.timestamp())))
.collect(Collectors.toUnmodifiableList());
} else {
this.messages = ImmutableList.<T>builder().addAll(messages).add(newMessage).build();
}

Identifier oldUserId = old.userId();
Identifier userId = userId();
Identifier oldBotId = old.botId();
Identifier botId = botId();
Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this thread state");
}

private int compare(int priority1, int priority2){
return priority1 > priority2 ? +1 : priority1 < priority2 ? -1 : 0;
}

public static <T extends Message> ThreadState<T> of(T message) {
return new ThreadState<>(message);
}
Expand All @@ -78,8 +82,9 @@ public Identifier botId() {
}

public T newMessageFromBot(Instant timestamp, String message) {
return messageFactory.newMessage(
T newMessage = messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
return newMessage.withParentMessage(tail());
}

public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
Expand All @@ -97,4 +102,5 @@ public List<T> messages() {
public T tail() {
return messages.get(messages.size() - 1);
}

}
12 changes: 10 additions & 2 deletions src/main/java/com/meta/cp4m/message/WAMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package com.meta.cp4m.message;

import com.meta.cp4m.Identifier;
import org.checkerframework.checker.lock.qual.NewObject;

import java.time.Instant;

public record WAMessage(
Expand All @@ -17,5 +19,11 @@ public record WAMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role)
implements Message {}
Role role,
Message parentMessage)
implements Message {
@Override
public @NewObject Message withParentMessage(Message parentMessage) {
return new WAMessage(timestamp(),instanceId(),senderId(),recipientId(),message(), role(), parentMessage);
}
}
11 changes: 3 additions & 8 deletions src/main/java/com/meta/cp4m/message/WAMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class WAMessageHandler implements MessageHandler<WAMessage> {
private static final String API_VERSION = "v17.0";
private static final JsonMapper MAPPER = Utils.JSON_MAPPER;
private static final Logger LOGGER = LoggerFactory.getLogger(WAMessageHandler.class);
private static final MessageFactory<WAMessage> MESSAGE_FACTORY = MessageFactory.instance(WAMessage.class);

/**
* <a
Expand Down Expand Up @@ -95,15 +96,9 @@ private List<WAMessage> post(Context ctx, WebhookPayload payload) {
continue;
}
TextWebhookMessage textMessage = (TextWebhookMessage) message;
waMessages.add(
new WAMessage(
message.timestamp(),
message.id(),
message.from(),
phoneNumberId,
textMessage.text().body(),
Message.Role.USER));
WAMessage waMessage = MESSAGE_FACTORY.newMessage(message.timestamp(), textMessage.text().body(), message.from(), phoneNumberId,message.id(), Message.Role.USER);
readExecutor.execute(() -> markRead(phoneNumberId, textMessage.id().toString()));
waMessages.add(waMessage);
}
});
return waMessages;
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/com/meta/cp4m/store/ChatStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package com.meta.cp4m.store;

import com.meta.cp4m.Identifier;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.util.List;
Expand All @@ -28,4 +29,6 @@ public interface ChatStore<T extends Message> {
long size();

List<ThreadState<T>> list();

ThreadState<T> get(Identifier threadId);
}
7 changes: 6 additions & 1 deletion src/main/java/com/meta/cp4m/store/MemoryStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class MemoryStore<T extends Message> implements ChatStore<T> {
}

@Override
public ThreadState<T> add(T message) {
public ThreadState<T> add(T message) {
return this.store
.asMap()
.compute(
Expand All @@ -53,4 +53,9 @@ public long size() {
public List<ThreadState<T>> list() {
return store.asMap().values().stream().toList();
}

@Override
public ThreadState<T> get(Identifier threadId){
return this.store.asMap().get(threadId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ void chunkingHappens() throws IOException {
Stream.generate(() -> "0123456789.").limit(300).collect(Collectors.joining(" "));
FBMessage bigMessage =
new FBMessage(
Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER);
Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER,null);
messageHandler.respond(bigMessage);
assertThat(requests.size()).isEqualTo(300);
assertThat(requests).allSatisfy(m -> assertThat(m.body()).contains("0123456789"));
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/com/meta/cp4m/message/MessageTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ void threadId() {
Identifier id0 = Identifier.from("0");
Identifier id1 = Identifier.from("1");
Identifier id2 = Identifier.from("2");
Message message = new FBMessage(timestamp, id0, id1, id2, "", Message.Role.ASSISTANT);
Message response = new FBMessage(timestamp, id0, id2, id1, "", Message.Role.ASSISTANT);
Message message = new FBMessage(timestamp, id0, id1, id2, "", Message.Role.ASSISTANT, null);
Message response = new FBMessage(timestamp, id0, id2, id1, "", Message.Role.ASSISTANT, null);
assertThat(message.threadId()).isEqualTo(response.threadId());

message =
Expand All @@ -33,15 +33,15 @@ void threadId() {
Identifier.from("12"),
Identifier.from("34"),
"",
Message.Role.ASSISTANT);
Message.Role.ASSISTANT, null);
response =
new FBMessage(
timestamp,
id0,
Identifier.from("1"),
Identifier.from("234"),
"",
Message.Role.ASSISTANT);
Message.Role.ASSISTANT, null);
assertThat(message.threadId()).isNotEqualTo(response.threadId());
}
}
4 changes: 2 additions & 2 deletions src/test/java/com/meta/cp4m/message/ServiceTestHarness.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ public URI webserverURI() {

public URI serviceURI() {
try {
return URIBuilder.localhost()
return URIBuilder.loopbackAddress()
.appendPath(SERVICE_PATH)
.setScheme("http")
.setPort(servicePort())
.build();
} catch (URISyntaxException | UnknownHostException e) {
} catch (URISyntaxException e) {
// this should be impossible
throw new RuntimeException(e);
}
Expand Down
4 changes: 3 additions & 1 deletion src/test/java/com/meta/cp4m/store/MemoryStoreTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ void test() {
Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER);
thread = memoryStore.add(message2);
assertThat(memoryStore.size()).isEqualTo(1);
assertThat(thread.messages()).hasSize(2).contains(message, message2);
assertThat(thread.messages()).hasSize(2);
assertThat(thread.messages().get(0).instanceId()).isSameAs(message.instanceId());
assertThat(thread.messages().get(1).instanceId()).isSameAs(message2.instanceId());

FBMessage message3 =
messageFactory.newMessage(
Expand Down
Loading
Loading