Skip to content

Commit

Permalink
Add a Redis based chat memory store
Browse files Browse the repository at this point in the history
Co-authored-by: Georgios Andrianakis <[email protected]>>
  • Loading branch information
2 people authored and geoand committed Apr 1, 2024
1 parent 390955e commit 21accb4
Show file tree
Hide file tree
Showing 16 changed files with 805 additions and 1 deletion.
71 changes: 71 additions & 0 deletions chatmemorystore-redis/deployment/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-chatmemorystore-redis-parent</artifactId>
<version>999-SNAPSHOT</version>
</parent>
<artifactId>quarkus-langchain4j-chatmemorystore-redis-deployment</artifactId>
<name>Quarkus Langchain4j - Redis Chat Memory Store - Deployment</name>
<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-arc-deployment</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-chatmemorystore-redis</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-redis-client-deployment</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-core-deployment</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>${assertj.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.wiremock</groupId>
<artifactId>wiremock-standalone</artifactId>
<version>${wiremock.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai-deployment</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<annotationProcessorPaths>
<path>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-extension-processor</artifactId>
<version>${quarkus.version}</version>
</path>
</annotationProcessorPaths>
</configuration>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.quarkiverse.langchain4j.chatmemorystore.redis.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;

@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.chatmemorystore.redis")
public interface RedisChatMemoryStoreBuildTimeConfig {

/**
* The name of the Redis client to use. These clients are configured by means of the `redis-client` extension.
* If unspecified, it will use the default Redis client.
*/
Optional<String> clientName();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package io.quarkiverse.langchain4j.chatmemorystore.redis.deployment;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Default;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.ClassType;
import org.jboss.jandex.DotName;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkiverse.langchain4j.chatmemorystore.RedisChatMemoryStore;
import io.quarkiverse.langchain4j.chatmemorystore.redis.runtime.RedisChatMemoryStoreRecorder;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.FeatureBuildItem;
import io.quarkus.redis.client.RedisClientName;
import io.quarkus.redis.datasource.RedisDataSource;
import io.quarkus.redis.deployment.client.RequestedRedisClientBuildItem;
import io.quarkus.redis.runtime.client.config.RedisConfig;

class RedisChatMemoryStoreProcessor {

public static final DotName REDIS_CHAT_MEMORY_STORE = DotName.createSimple(RedisChatMemoryStore.class);
private static final String FEATURE = "langchain4j-chatmemorystore-redis";

@BuildStep
FeatureBuildItem feature() {
return new FeatureBuildItem(FEATURE);
}

@BuildStep
public RequestedRedisClientBuildItem requestRedisClient(RedisChatMemoryStoreBuildTimeConfig config) {
return new RequestedRedisClientBuildItem(config.clientName().orElse(RedisConfig.DEFAULT_CLIENT_NAME));
}

@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
public void createMemoryStoreBean(
BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer,
BuildProducer<SyntheticBeanBuildItem> beanProducer,
RedisChatMemoryStoreRecorder recorder,
RedisChatMemoryStoreBuildTimeConfig buildTimeConfig) {
String clientName = buildTimeConfig.clientName().orElse(null);
AnnotationInstance redisClientQualifier;
if (clientName == null) {
redisClientQualifier = AnnotationInstance.builder(Default.class).build();
} else {
redisClientQualifier = AnnotationInstance.builder(RedisClientName.class)
.add("value", clientName)
.build();
}
beanProducer.produce(SyntheticBeanBuildItem
.configure(REDIS_CHAT_MEMORY_STORE)
.types(ClassType.create(ChatMemoryStore.class))
.setRuntimeInit()
.scope(ApplicationScoped.class)
.addInjectionPoint(ClassType.create(DotName.createSimple(RedisDataSource.class)),
redisClientQualifier)
.createWith(recorder.chatMemoryStoreFunction(clientName))
.done());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.quarkiverse.langchain4j.chatmemorystore.redis.test;

import static org.assertj.core.api.Assertions.as;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.InstanceOfAssertFactories.list;
import static org.assertj.core.api.InstanceOfAssertFactories.map;

import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import org.assertj.core.api.InstanceOfAssertFactory;
import org.assertj.core.api.ListAssert;
import org.assertj.core.api.MapAssert;

import com.fasterxml.jackson.core.type.TypeReference;

class MessageAssertUtils {

static final TypeReference<Map<String, Object>> MAP_TYPE_REF = new TypeReference<>() {
};
private static final InstanceOfAssertFactory<Map, MapAssert<String, String>> MAP_STRING_STRING = map(String.class,
String.class);
private static final InstanceOfAssertFactory<List, ListAssert<Map>> LIST_MAP = list(Map.class);

static void assertSingleRequestMessage(Map<String, Object> requestAsMap, String value) {
assertMessages(requestAsMap, (listOfMessages -> {
assertThat(listOfMessages).singleElement(as(MAP_STRING_STRING)).satisfies(message -> {
assertThat(message)
.containsEntry("role", "user")
.containsEntry("content", value);
});
}));
}

static void assertMultipleRequestMessage(Map<String, Object> requestAsMap, List<MessageContent> messageContents) {
assertMessages(requestAsMap, listOfMessages -> {
assertThat(listOfMessages).asInstanceOf(LIST_MAP).hasSize(messageContents.size()).satisfies(l -> {
for (int i = 0; i < messageContents.size(); i++) {
MessageContent messageContent = messageContents.get(i);
assertThat((Map<String, String>) l.get(i)).satisfies(message -> {
assertThat(message)
.containsEntry("role", messageContent.getRole());
if (messageContent.getContent() == null) {
if (message.containsKey("content")) {
assertThat(message).containsEntry("content", null);
}
} else {
assertThat(message).containsEntry("content", messageContent.getContent());
}

});
}
});
});
}

@SuppressWarnings("rawtypes")
static void assertMessages(Map<String, Object> requestAsMap, Consumer<List<? extends Map>> messagesAssertions) {
assertThat(requestAsMap).hasEntrySatisfying("messages",
o -> assertThat(o).asInstanceOf(list(Map.class)).satisfies(messagesAssertions));
}

static class MessageContent {
private final String role;
private final String content;

public MessageContent(String role, String content) {
this.role = role;
this.content = content;
}

public String getRole() {
return role;
}

public String getContent() {
return content;
}
}
}
Loading

0 comments on commit 21accb4

Please sign in to comment.