From 21accb4d1af0203b01933a484565d6a4d709cbd4 Mon Sep 17 00:00:00 2001 From: Sebastien Blanc Date: Fri, 29 Mar 2024 13:58:26 +0100 Subject: [PATCH] Add a Redis based chat memory store Co-authored-by: Georgios Andrianakis > --- chatmemorystore-redis/deployment/pom.xml | 71 ++++++ .../RedisChatMemoryStoreBuildTimeConfig.java | 19 ++ .../RedisChatMemoryStoreProcessor.java | 67 ++++++ .../redis/test/MessageAssertUtils.java | 81 +++++++ .../redis/test/RedisChatMemoryStoreTest.java | 225 ++++++++++++++++++ .../redis/test/WiremockUtils.java | 83 +++++++ .../src/test/resources/chat/default.json | 21 ++ chatmemorystore-redis/pom.xml | 17 ++ chatmemorystore-redis/runtime/pom.xml | 84 +++++++ .../chatmemorystore/ChatMessageCodec.java | 48 ++++ .../chatmemorystore/RedisChatMemoryStore.java | 45 ++++ .../runtime/RedisChatMemoryStoreRecorder.java | 31 +++ .../src/main/resources/META-INF/beans.xml | 0 .../resources/META-INF/quarkus-extension.yaml | 11 + .../ROOT/pages/includes/attributes.adoc | 2 +- pom.xml | 1 + 16 files changed, 805 insertions(+), 1 deletion(-) create mode 100644 chatmemorystore-redis/deployment/pom.xml create mode 100644 chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreBuildTimeConfig.java create mode 100644 chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreProcessor.java create mode 100644 chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/MessageAssertUtils.java create mode 100644 chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/RedisChatMemoryStoreTest.java create mode 100644 chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/WiremockUtils.java create mode 100644 chatmemorystore-redis/deployment/src/test/resources/chat/default.json create mode 100644 chatmemorystore-redis/pom.xml create mode 100644 chatmemorystore-redis/runtime/pom.xml create mode 100644 chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/ChatMessageCodec.java create mode 100644 chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/RedisChatMemoryStore.java create mode 100644 chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/runtime/RedisChatMemoryStoreRecorder.java create mode 100644 chatmemorystore-redis/runtime/src/main/resources/META-INF/beans.xml create mode 100644 chatmemorystore-redis/runtime/src/main/resources/META-INF/quarkus-extension.yaml diff --git a/chatmemorystore-redis/deployment/pom.xml b/chatmemorystore-redis/deployment/pom.xml new file mode 100644 index 000000000..93097bd41 --- /dev/null +++ b/chatmemorystore-redis/deployment/pom.xml @@ -0,0 +1,71 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-chatmemorystore-redis-parent + 999-SNAPSHOT + + quarkus-langchain4j-chatmemorystore-redis-deployment + Quarkus Langchain4j - Redis Chat Memory Store - Deployment + + + io.quarkus + quarkus-arc-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-chatmemorystore-redis + ${project.version} + + + io.quarkus + quarkus-redis-client-deployment + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core-deployment + ${project.version} + + + io.quarkus + quarkus-junit5-internal + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.wiremock + wiremock-standalone + ${wiremock.version} + test + + + io.quarkiverse.langchain4j + quarkus-langchain4j-openai-deployment + ${project.version} + test + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + + diff --git a/chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreBuildTimeConfig.java b/chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreBuildTimeConfig.java new file mode 100644 index 000000000..089165de2 --- /dev/null +++ b/chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreBuildTimeConfig.java @@ -0,0 +1,19 @@ +package io.quarkiverse.langchain4j.chatmemorystore.redis.deployment; + +import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; + +@ConfigRoot(phase = BUILD_TIME) +@ConfigMapping(prefix = "quarkus.langchain4j.chatmemorystore.redis") +public interface RedisChatMemoryStoreBuildTimeConfig { + + /** + * The name of the Redis client to use. These clients are configured by means of the `redis-client` extension. + * If unspecified, it will use the default Redis client. + */ + Optional clientName(); +} diff --git a/chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreProcessor.java b/chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreProcessor.java new file mode 100644 index 000000000..fe28447c8 --- /dev/null +++ b/chatmemorystore-redis/deployment/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/deployment/RedisChatMemoryStoreProcessor.java @@ -0,0 +1,67 @@ +package io.quarkiverse.langchain4j.chatmemorystore.redis.deployment; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Default; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.ClassType; +import org.jboss.jandex.DotName; + +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import io.quarkiverse.langchain4j.chatmemorystore.RedisChatMemoryStore; +import io.quarkiverse.langchain4j.chatmemorystore.redis.runtime.RedisChatMemoryStoreRecorder; +import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.SyntheticBeanBuildItem; +import io.quarkus.deployment.annotations.BuildProducer; +import io.quarkus.deployment.annotations.BuildStep; +import io.quarkus.deployment.annotations.ExecutionTime; +import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.FeatureBuildItem; +import io.quarkus.redis.client.RedisClientName; +import io.quarkus.redis.datasource.RedisDataSource; +import io.quarkus.redis.deployment.client.RequestedRedisClientBuildItem; +import io.quarkus.redis.runtime.client.config.RedisConfig; + +class RedisChatMemoryStoreProcessor { + + public static final DotName REDIS_CHAT_MEMORY_STORE = DotName.createSimple(RedisChatMemoryStore.class); + private static final String FEATURE = "langchain4j-chatmemorystore-redis"; + + @BuildStep + FeatureBuildItem feature() { + return new FeatureBuildItem(FEATURE); + } + + @BuildStep + public RequestedRedisClientBuildItem requestRedisClient(RedisChatMemoryStoreBuildTimeConfig config) { + return new RequestedRedisClientBuildItem(config.clientName().orElse(RedisConfig.DEFAULT_CLIENT_NAME)); + } + + @BuildStep + @Record(ExecutionTime.RUNTIME_INIT) + public void createMemoryStoreBean( + BuildProducer additionalBeanProducer, + BuildProducer beanProducer, + RedisChatMemoryStoreRecorder recorder, + RedisChatMemoryStoreBuildTimeConfig buildTimeConfig) { + String clientName = buildTimeConfig.clientName().orElse(null); + AnnotationInstance redisClientQualifier; + if (clientName == null) { + redisClientQualifier = AnnotationInstance.builder(Default.class).build(); + } else { + redisClientQualifier = AnnotationInstance.builder(RedisClientName.class) + .add("value", clientName) + .build(); + } + beanProducer.produce(SyntheticBeanBuildItem + .configure(REDIS_CHAT_MEMORY_STORE) + .types(ClassType.create(ChatMemoryStore.class)) + .setRuntimeInit() + .scope(ApplicationScoped.class) + .addInjectionPoint(ClassType.create(DotName.createSimple(RedisDataSource.class)), + redisClientQualifier) + .createWith(recorder.chatMemoryStoreFunction(clientName)) + .done()); + } + +} diff --git a/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/MessageAssertUtils.java b/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/MessageAssertUtils.java new file mode 100644 index 000000000..ed14f9ee6 --- /dev/null +++ b/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/MessageAssertUtils.java @@ -0,0 +1,81 @@ +package io.quarkiverse.langchain4j.chatmemorystore.redis.test; + +import static org.assertj.core.api.Assertions.as; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.list; +import static org.assertj.core.api.InstanceOfAssertFactories.map; + +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.assertj.core.api.InstanceOfAssertFactory; +import org.assertj.core.api.ListAssert; +import org.assertj.core.api.MapAssert; + +import com.fasterxml.jackson.core.type.TypeReference; + +class MessageAssertUtils { + + static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + private static final InstanceOfAssertFactory> MAP_STRING_STRING = map(String.class, + String.class); + private static final InstanceOfAssertFactory> LIST_MAP = list(Map.class); + + static void assertSingleRequestMessage(Map requestAsMap, String value) { + assertMessages(requestAsMap, (listOfMessages -> { + assertThat(listOfMessages).singleElement(as(MAP_STRING_STRING)).satisfies(message -> { + assertThat(message) + .containsEntry("role", "user") + .containsEntry("content", value); + }); + })); + } + + static void assertMultipleRequestMessage(Map requestAsMap, List messageContents) { + assertMessages(requestAsMap, listOfMessages -> { + assertThat(listOfMessages).asInstanceOf(LIST_MAP).hasSize(messageContents.size()).satisfies(l -> { + for (int i = 0; i < messageContents.size(); i++) { + MessageContent messageContent = messageContents.get(i); + assertThat((Map) l.get(i)).satisfies(message -> { + assertThat(message) + .containsEntry("role", messageContent.getRole()); + if (messageContent.getContent() == null) { + if (message.containsKey("content")) { + assertThat(message).containsEntry("content", null); + } + } else { + assertThat(message).containsEntry("content", messageContent.getContent()); + } + + }); + } + }); + }); + } + + @SuppressWarnings("rawtypes") + static void assertMessages(Map requestAsMap, Consumer> messagesAssertions) { + assertThat(requestAsMap).hasEntrySatisfying("messages", + o -> assertThat(o).asInstanceOf(list(Map.class)).satisfies(messagesAssertions)); + } + + static class MessageContent { + private final String role; + private final String content; + + public MessageContent(String role, String content) { + this.role = role; + this.content = content; + } + + public String getRole() { + return role; + } + + public String getContent() { + return content; + } + } +} diff --git a/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/RedisChatMemoryStoreTest.java b/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/RedisChatMemoryStoreTest.java new file mode 100644 index 000000000..1487733df --- /dev/null +++ b/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/RedisChatMemoryStoreTest.java @@ -0,0 +1,225 @@ +package io.quarkiverse.langchain4j.chatmemorystore.redis.test; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static dev.langchain4j.data.message.ChatMessageType.AI; +import static dev.langchain4j.data.message.ChatMessageType.USER; +import static io.quarkiverse.langchain4j.chatmemorystore.redis.test.MessageAssertUtils.assertMultipleRequestMessage; +import static io.quarkiverse.langchain4j.chatmemorystore.redis.test.MessageAssertUtils.assertSingleRequestMessage; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.tuple; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import io.quarkiverse.langchain4j.ChatMemoryRemover; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.chatmemorystore.RedisChatMemoryStore; +import io.quarkus.redis.datasource.RedisDataSource; +import io.quarkus.test.QuarkusUnitTest; + +public class RedisChatMemoryStoreTest { + + public static final int FIRST_MEMORY_ID = 1; + public static final int SECOND_MEMORY_ID = 2; + private static final int WIREMOCK_PORT = 8089; + private static final String API_KEY = "test"; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub(API_KEY)); + } + + @RegisterAiService + @ApplicationScoped + interface ChatWithSeparateMemoryForEachUser { + + String chat(@MemoryId int memoryId, @UserMessage String userMessage); + } + + @Inject + ChatMemoryStore chatMemoryStore; + + @Inject + RedisDataSource redisDataSource; + + @Inject + ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser; + + @Test + void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException { + // assert the bean type is correct + assertThat(chatMemoryStore).isInstanceOf(RedisChatMemoryStore.class); + + /* **** First request for user 1 **** */ + String firstMessageFromFirstUser = "Hello, my name is Klaus"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(API_KEY, + "Nice to meet you Klaus")); + String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID, firstMessageFromFirstUser); + + // assert response + assertThat(firstAiResponseToFirstUser).isEqualTo("Nice to meet you Klaus"); + + // assert request + assertSingleRequestMessage(getRequestAsMap(), firstMessageFromFirstUser); + + // assert chat memory + assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).hasSize(2) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser)); + + /* **** First request for user 2 **** */ + wireMockServer.resetRequests(); + + String firstMessageFromSecondUser = "Hello, my name is Francine"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(API_KEY, + "Nice to meet you Francine")); + String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID, + firstMessageFromSecondUser); + + // assert response + assertThat(firstAiResponseToSecondUser).isEqualTo("Nice to meet you Francine"); + + // assert request + assertSingleRequestMessage(getRequestAsMap(), firstMessageFromSecondUser); + + // assert chat memory + assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).hasSize(2) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser)); + + /* **** Second request for user 1 **** */ + wireMockServer.resetRequests(); + + String secondsMessageFromFirstUser = "What is my name?"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(API_KEY, + "Your name is Klaus")); + String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID, + secondsMessageFromFirstUser); + + // assert response + assertThat(secondAiMessageToFirstUser).contains("Klaus"); + + // assert request + assertMultipleRequestMessage(getRequestAsMap(), + List.of( + new MessageAssertUtils.MessageContent("user", firstMessageFromFirstUser), + new MessageAssertUtils.MessageContent("assistant", firstAiResponseToFirstUser), + new MessageAssertUtils.MessageContent("user", secondsMessageFromFirstUser))); + + // assert chat memory + assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).hasSize(4) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser), + tuple(USER, secondsMessageFromFirstUser), tuple(AI, secondAiMessageToFirstUser)); + + /* **** Second request for user 2 **** */ + wireMockServer.resetRequests(); + + String secondsMessageFromSecondUser = "What is my name?"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(API_KEY, + "Your name is Francine")); + String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID, + secondsMessageFromSecondUser); + + // assert response + assertThat(secondAiMessageToSecondUser).contains("Francine"); + + // assert request + assertMultipleRequestMessage(getRequestAsMap(), + List.of( + new MessageAssertUtils.MessageContent("user", firstMessageFromSecondUser), + new MessageAssertUtils.MessageContent("assistant", firstAiResponseToSecondUser), + new MessageAssertUtils.MessageContent("user", secondsMessageFromSecondUser))); + + // assert chat memory + assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).hasSize(4) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser), + tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser)); + + // assert our chat memory is used + assertThat(redisDataSource.key().exists("" + FIRST_MEMORY_ID, "" + SECOND_MEMORY_ID)).isEqualTo(2); + + // remove the first entry + ChatMemoryRemover.remove(chatWithSeparateMemoryForEachUser, FIRST_MEMORY_ID); + assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).isEmpty(); + assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).isNotEmpty(); + + // remove the second entry + ChatMemoryRemover.remove(chatWithSeparateMemoryForEachUser, SECOND_MEMORY_ID); + assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).isEmpty(); + assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).isEmpty(); + + // now assert that our store was used for delete + assertThat(redisDataSource.key().exists("" + FIRST_MEMORY_ID, "" + SECOND_MEMORY_ID)).isEqualTo(0); + } + + private Map getRequestAsMap() throws IOException { + return getRequestAsMap(getRequestBody()); + } + + private Map getRequestAsMap(byte[] body) throws IOException { + return mapper.readValue(body, MAP_TYPE_REF); + } + + private byte[] getRequestBody() { + assertThat(wireMockServer.getAllServeEvents()).hasSize(1); + ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this works because we reset requests for Wiremock before each test + return getRequestBody(serveEvent); + } + + private byte[] getRequestBody(ServeEvent serveEvent) { + LoggedRequest request = serveEvent.getRequest(); + assertThat(request.getBody()).isNotEmpty(); + return request.getBody(); + } +} diff --git a/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/WiremockUtils.java b/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/WiremockUtils.java new file mode 100644 index 000000000..b6aa3aee3 --- /dev/null +++ b/chatmemorystore-redis/deployment/src/test/java/io/quarkiverse/langchain4j/chatmemorystore/redis/test/WiremockUtils.java @@ -0,0 +1,83 @@ +package io.quarkiverse.langchain4j.chatmemorystore.redis.test; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; + +import com.github.tomakehurst.wiremock.client.MappingBuilder; +import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder; +import com.github.tomakehurst.wiremock.matching.RequestPatternBuilder; + +import io.quarkus.bootstrap.classloading.QuarkusClassLoader; + +public class WiremockUtils { + + private static final String DEFAULT_CHAT_MESSAGE_CONTENT = "Hello there, how may I assist you today?"; + private static final String CHAT_MESSAGE_CONTENT_TEMPLATE; + private static final String DEFAULT_CHAT_RESPONSE_BODY; + public static final ResponseDefinitionBuilder CHAT_RESPONSE_WITHOUT_BODY; + private static final ResponseDefinitionBuilder DEFAULT_CHAT_RESPONSE; + + static { + try (InputStream is = getClassLoader().getResourceAsStream("chat/default.json")) { + CHAT_MESSAGE_CONTENT_TEMPLATE = new String(is.readAllBytes(), StandardCharsets.UTF_8); + DEFAULT_CHAT_RESPONSE_BODY = String.format(CHAT_MESSAGE_CONTENT_TEMPLATE, DEFAULT_CHAT_MESSAGE_CONTENT); + CHAT_RESPONSE_WITHOUT_BODY = aResponse().withHeader("Content-Type", "application/json"); + DEFAULT_CHAT_RESPONSE = CHAT_RESPONSE_WITHOUT_BODY + .withBody(DEFAULT_CHAT_RESPONSE_BODY); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static ClassLoader getClassLoader() { + ClassLoader loader = Thread.currentThread().getContextClassLoader(); + while (loader instanceof QuarkusClassLoader) { + loader = loader.getParent(); + } + return loader; + } + + private static ResponseDefinitionBuilder defaultChatCompletionResponse() { + return DEFAULT_CHAT_RESPONSE; + } + + public static MappingBuilder chatCompletionMapping(String token) { + return post(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer " + token)); + } + + public static RequestPatternBuilder chatCompletionRequestPattern(String token) { + return postRequestedFor(urlEqualTo("/v1/chat/completions")) + .withHeader("Authorization", equalTo("Bearer " + token)); + } + + public static RequestPatternBuilder chatCompletionRequestPattern(String token, String organization) { + return chatCompletionRequestPattern(token) + .withHeader("OpenAI-Organization", equalTo(organization)); + } + + public static MappingBuilder moderationMapping(String token) { + return post(urlEqualTo("/v1/moderations")) + .withHeader("Authorization", equalTo("Bearer " + token)); + } + + public static MappingBuilder defaultChatCompletionsStub(String token) { + return chatCompletionMapping(token) + .willReturn(defaultChatCompletionResponse()); + } + + public static MappingBuilder chatCompletionsMessageContent(String token, String messageContent) { + return chatCompletionMapping(token) + .willReturn( + CHAT_RESPONSE_WITHOUT_BODY.withBody(String.format(CHAT_MESSAGE_CONTENT_TEMPLATE, messageContent))); + } + +} diff --git a/chatmemorystore-redis/deployment/src/test/resources/chat/default.json b/chatmemorystore-redis/deployment/src/test/resources/chat/default.json new file mode 100644 index 000000000..8107e7b26 --- /dev/null +++ b/chatmemorystore-redis/deployment/src/test/resources/chat/default.json @@ -0,0 +1,21 @@ +{ + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "gpt-3.5-turbo-instruct", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "%s" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } +} diff --git a/chatmemorystore-redis/pom.xml b/chatmemorystore-redis/pom.xml new file mode 100644 index 000000000..e7a28dfdc --- /dev/null +++ b/chatmemorystore-redis/pom.xml @@ -0,0 +1,17 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-parent + 999-SNAPSHOT + + quarkus-langchain4j-chatmemorystore-redis-parent + pom + Quarkus Langchain4j - Redis Chat Memory Store - Parent + + deployment + runtime + + diff --git a/chatmemorystore-redis/runtime/pom.xml b/chatmemorystore-redis/runtime/pom.xml new file mode 100644 index 000000000..9e2469a24 --- /dev/null +++ b/chatmemorystore-redis/runtime/pom.xml @@ -0,0 +1,84 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-chatmemorystore-redis-parent + 999-SNAPSHOT + + quarkus-langchain4j-chatmemorystore-redis + Quarkus Langchain4j - Redis Chat Memory Store - Runtime + Provides a Redis Chat Memory Store + + + io.quarkus + quarkus-arc + + + io.quarkus + quarkus-redis-client + + + io.quarkus + quarkus-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-core + ${project.version} + + + + + + io.quarkus + quarkus-extension-maven-plugin + ${quarkus.version} + + + compile + + extension-descriptor + + + ${project.groupId}:${project.artifactId}-deployment:${project.version} + + + + + + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${quarkus.version} + + + + + + maven-jar-plugin + + + generate-codestart-jar + generate-resources + + jar + + + ${project.basedir}/src/main + + codestarts/** + + codestarts + true + + + + + + + diff --git a/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/ChatMessageCodec.java b/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/ChatMessageCodec.java new file mode 100644 index 000000000..299bd107b --- /dev/null +++ b/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/ChatMessageCodec.java @@ -0,0 +1,48 @@ +package io.quarkiverse.langchain4j.chatmemorystore; + +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.Collections; +import java.util.List; + +import jakarta.inject.Singleton; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; + +import dev.langchain4j.data.message.ChatMessage; +import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; +import io.quarkus.redis.datasource.codecs.Codec; + +@Singleton +public class ChatMessageCodec implements Codec { + private static final TypeReference> MESSAGE_LIST_TYPE = new TypeReference<>() { + }; + + @Override + public boolean canHandle(Type clazz) { + return MESSAGE_LIST_TYPE.getType().equals(clazz); + } + + @Override + public byte[] encode(Object item) { + try { + return QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.writeValueAsBytes(item); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + @Override + public Object decode(byte[] item) { + if (item == null) { + return Collections.emptyList(); + } + try { + return QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readValue(item, MESSAGE_LIST_TYPE); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/RedisChatMemoryStore.java b/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/RedisChatMemoryStore.java new file mode 100644 index 000000000..f48f05523 --- /dev/null +++ b/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/RedisChatMemoryStore.java @@ -0,0 +1,45 @@ +package io.quarkiverse.langchain4j.chatmemorystore; + +import java.util.Collections; +import java.util.List; + +import com.fasterxml.jackson.core.type.TypeReference; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import io.quarkus.redis.datasource.RedisDataSource; +import io.quarkus.redis.datasource.keys.KeyCommands; +import io.quarkus.redis.datasource.value.ValueCommands; + +public class RedisChatMemoryStore implements ChatMemoryStore { + + private final ValueCommands> valueCommands; + private final KeyCommands keyCommands; + + public RedisChatMemoryStore(RedisDataSource redisDataSource) { + this.valueCommands = redisDataSource.value(new TypeReference>() { + }); + this.keyCommands = redisDataSource.key(String.class); + } + + @Override + public void deleteMessages(Object memoryId) { + keyCommands.del(memoryId.toString()); + + } + + @Override + public List getMessages(Object memoryId) { + List chatMessages = valueCommands.get(memoryId.toString()); + if (chatMessages != null) { + return chatMessages; + } else { + return Collections.emptyList(); + } + } + + @Override + public void updateMessages(Object memoryId, List messages) { + valueCommands.set(memoryId.toString(), messages); + } +} diff --git a/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/runtime/RedisChatMemoryStoreRecorder.java b/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/runtime/RedisChatMemoryStoreRecorder.java new file mode 100644 index 000000000..4d948a3db --- /dev/null +++ b/chatmemorystore-redis/runtime/src/main/java/io/quarkiverse/langchain4j/chatmemorystore/redis/runtime/RedisChatMemoryStoreRecorder.java @@ -0,0 +1,31 @@ +package io.quarkiverse.langchain4j.chatmemorystore.redis.runtime; + +import java.util.function.Function; + +import jakarta.enterprise.inject.Default; + +import io.quarkiverse.langchain4j.chatmemorystore.RedisChatMemoryStore; +import io.quarkus.arc.SyntheticCreationalContext; +import io.quarkus.redis.client.RedisClientName; +import io.quarkus.redis.datasource.RedisDataSource; +import io.quarkus.runtime.annotations.Recorder; + +@Recorder +public class RedisChatMemoryStoreRecorder { + public Function, RedisChatMemoryStore> chatMemoryStoreFunction( + String clientName) { + return new Function<>() { + @Override + public RedisChatMemoryStore apply(SyntheticCreationalContext context) { + RedisDataSource dataSource; + if (clientName == null) { + dataSource = context.getInjectedReference(RedisDataSource.class, new Default.Literal()); + } else { + dataSource = context.getInjectedReference(RedisDataSource.class, + new RedisClientName.Literal(clientName)); + } + return new RedisChatMemoryStore(dataSource); + } + }; + } +} diff --git a/chatmemorystore-redis/runtime/src/main/resources/META-INF/beans.xml b/chatmemorystore-redis/runtime/src/main/resources/META-INF/beans.xml new file mode 100644 index 000000000..e69de29bb diff --git a/chatmemorystore-redis/runtime/src/main/resources/META-INF/quarkus-extension.yaml b/chatmemorystore-redis/runtime/src/main/resources/META-INF/quarkus-extension.yaml new file mode 100644 index 000000000..da5552b8b --- /dev/null +++ b/chatmemorystore-redis/runtime/src/main/resources/META-INF/quarkus-extension.yaml @@ -0,0 +1,11 @@ +name: Quarkus Langchain4j - Redis Chat Memory Store +artifact: ${project.groupId}:${project.artifactId}:${project.version} +description: Provides a Redis Chat Memory Store +metadata: + keywords: + - ai + - langchain4j + - redis + categories: + - "miscellaneous" + status: "experimental" diff --git a/docs/modules/ROOT/pages/includes/attributes.adoc b/docs/modules/ROOT/pages/includes/attributes.adoc index 9e2c909d8..0471ff587 100644 --- a/docs/modules/ROOT/pages/includes/attributes.adoc +++ b/docs/modules/ROOT/pages/includes/attributes.adoc @@ -1,3 +1,3 @@ :project-version: 0.10.2 :langchain4j-version: 0.29.1 -:examples-dir: ./../examples/ +:examples-dir: ./../examples/ \ No newline at end of file diff --git a/pom.xml b/pom.xml index 3cd452b53..adde79c77 100644 --- a/pom.xml +++ b/pom.xml @@ -13,6 +13,7 @@ Quarkus LangChain4j - Parent bam + chatmemorystore-redis chroma cohere core