From cf17b8cbdb725c2a5bc28865dc345fb6f6d9feff Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 5 Dec 2023 15:45:48 +0200 Subject: [PATCH] Make @RegisterAiService beans request scoped by default This is done because otherwise the chat memory does not get cleared properly. Furthermore, add a way to remove memory entries when the service goes out of scope Fixes: #95 --- .../deployment/AiServicesProcessor.java | 48 +++- .../DeclarativeAiServiceBuildItem.java | 10 +- .../langchain4j/RegisterAiService.java | 6 +- .../RemovableChatMemoryProvider.java | 13 + .../AiServiceMethodImplementationSupport.java | 1 + .../DeclarativeAiServiceBeanDestroyer.java | 24 ++ .../aiservice/QuarkusAiServiceContext.java | 31 +++ .../aiservices/AuditingServiceTest.java | 1 + .../BeanDeclarativeAiServicesTest.java | 1 + .../aiservices/DeclarativeAiServicesTest.java | 5 + .../aiservices/RemovableChatMemoryTest.java | 258 ++++++++++++++++++ samples/chatbot/pom.xml | 6 +- .../sample/chatbot/ChatBotWebSocket.java | 9 +- samples/csv-chatbot/pom.xml | 6 +- .../sample/chatbot/ChatBotWebSocket.java | 14 +- .../langchain4j/sample/ChatMemoryBean.java | 15 +- 16 files changed, 416 insertions(+), 32 deletions(-) create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceBeanDestroyer.java create mode 100644 openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index bc47fde56..2e35fa199 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -23,7 +23,6 @@ import java.util.function.Predicate; import java.util.stream.Collectors; -import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.inject.Instance; import org.jboss.jandex.AnnotationInstance; @@ -50,6 +49,7 @@ import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport; +import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer; import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper; import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; @@ -60,6 +60,8 @@ import io.quarkus.arc.deployment.AdditionalBeanBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; +import io.quarkus.arc.processor.BuiltinScope; +import io.quarkus.arc.processor.ScopeInfo; import io.quarkus.builder.item.MultiBuildItem; import io.quarkus.deployment.Capabilities; import io.quarkus.deployment.Capability; @@ -101,6 +103,9 @@ public class AiServicesProcessor { private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod( AiServiceMethodImplementationSupport.class, "implement", Object.class, AiServiceMethodImplementationSupport.Input.class); + + private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_CLOSE = MethodDescriptor.ofMethod( + QuarkusAiServiceContext.class, "close", void.class); public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class); @BuildStep @@ -211,6 +216,9 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer); } + BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo); + ScopeInfo cdiScope = declaredScope != null ? declaredScope.getInfo() : BuiltinScope.REQUEST.getInfo(); + declarativeAiServiceProducer.produce( new DeclarativeAiServiceBuildItem( declarativeAiServiceClassInfo, @@ -218,7 +226,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, toolDotNames, chatMemoryProviderSupplierClassDotName, retrieverSupplierClassDotName, - auditServiceClassSupplierName)); + auditServiceClassSupplierName, + cdiScope)); } if (needChatModelBean) { @@ -285,8 +294,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, toolClassNames, chatMemoryProviderSupplierClassName, retrieverSupplierClassName, auditServiceClassSupplierName))) + .destroyer(DeclarativeAiServiceBeanDestroyer.class) .setRuntimeInit() - .scope(ApplicationScoped.class); + .scope(bi.getCdiScope()); if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed? configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL)); needsChatModelBean = true; @@ -403,8 +413,10 @@ public void handleAiServices(AiServicesRecorder recorder, Set detectedForCreate = new HashSet<>(nameToUsed.keySet()); addCreatedAware(index, detectedForCreate); addIfacesWithMessageAnns(index, detectedForCreate); - detectedForCreate.addAll(declarativeAiServiceItems.stream().map(bi -> bi.getServiceClassInfo().name().toString()) - .collect(Collectors.toList())); + Set registeredAiServiceClassNames = declarativeAiServiceItems.stream() + .map(bi -> bi.getServiceClassInfo().name().toString()).collect( + Collectors.toUnmodifiableSet()); + detectedForCreate.addAll(registeredAiServiceClassNames); Set ifacesForCreate = new HashSet<>(); for (String className : detectedForCreate) { @@ -453,12 +465,18 @@ public void handleAiServices(AiServicesRecorder recorder, methodsToImplement.add(method); } - String implClassName = iface.name().toString() + "$$QuarkusImpl"; - try (ClassCreator classCreator = ClassCreator.builder() + String ifaceName = iface.name().toString(); + String implClassName = ifaceName + "$$QuarkusImpl"; + boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName); + + ClassCreator.Builder classCreatorBuilder = ClassCreator.builder() .classOutput(classOutput) .className(implClassName) - .interfaces(iface.name().toString()) - .build()) { + .interfaces(ifaceName); + if (isRegisteredService) { + classCreatorBuilder.interfaces(AutoCloseable.class); + } + try (ClassCreator classCreator = classCreatorBuilder.build()) { FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class) .setModifiers(Modifier.PRIVATE | Modifier.FINAL) @@ -480,7 +498,7 @@ public void handleAiServices(AiServicesRecorder recorder, MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo)); ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis()); ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO, - mc.load(iface.name().toString()), + mc.load(ifaceName), mc.load(methodId)); ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount()); for (int i = 0; i < methodInfo.parametersCount(); i++) { @@ -498,8 +516,16 @@ public void handleAiServices(AiServicesRecorder recorder, aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo)); } + + if (isRegisteredService) { + MethodCreator mc = classCreator.getMethodCreator( + MethodDescriptor.ofMethod(implClassName, "close", void.class)); + ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis()); + mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle); + mc.returnVoid(); + } } - perClassMetadata.put(iface.name().toString(), new AiServiceClassCreateInfo(perMethodMetadata, implClassName)); + perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName)); // make the constructor accessible reflectively since that is how we create the instance reflectiveClassProducer.produce(ReflectiveClassBuildItem.builder(implClassName).build()); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index fa32d3b91..97194e993 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -5,6 +5,7 @@ import org.jboss.jandex.ClassInfo; import org.jboss.jandex.DotName; +import io.quarkus.arc.processor.ScopeInfo; import io.quarkus.builder.item.MultiBuildItem; /** @@ -19,18 +20,21 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final DotName chatMemoryProviderSupplierClassDotName; private final DotName retrieverSupplierClassDotName; private final DotName auditServiceClassSupplierDotName; + private final ScopeInfo cdiScope; public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName, List toolDotNames, DotName chatMemoryProviderSupplierClassDotName, DotName retrieverSupplierClassDotName, - DotName auditServiceClassSupplierDotName) { + DotName auditServiceClassSupplierDotName, + ScopeInfo cdiScope) { this.serviceClassInfo = serviceClassInfo; this.languageModelSupplierClassDotName = languageModelSupplierClassDotName; this.toolDotNames = toolDotNames; this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName; this.retrieverSupplierClassDotName = retrieverSupplierClassDotName; this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName; + this.cdiScope = cdiScope; } public ClassInfo getServiceClassInfo() { @@ -56,4 +60,8 @@ public DotName getRetrieverSupplierClassDotName() { public DotName getAuditServiceClassSupplierDotName() { return auditServiceClassSupplierDotName; } + + public ScopeInfo getCdiScope() { + return cdiScope; + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 94332e917..c05f58de2 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -7,8 +7,6 @@ import java.lang.annotation.Target; import java.util.function.Supplier; -import jakarta.enterprise.context.ApplicationScoped; - import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; @@ -24,7 +22,9 @@ * while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional), * {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist). *

- * NOTE: The resulting CDI bean is {@link ApplicationScoped}. + * NOTE: The resulting CDI bean is {@link jakarta.enterprise.context.RequestScoped} be default. If you need to change the scope, + * simply annotate the class with a CDI scope. + * CAUTION: When using anything other than the request scope, you need to be very careful with the chat memory implementation. *

* NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated * for the method invocations. diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java new file mode 100644 index 000000000..e309f27d8 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java @@ -0,0 +1,13 @@ +package io.quarkiverse.langchain4j; + +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; + +/** + * Extends {@link ChatMemoryProvider} to allow for removing {@link ChatMemory} + * when it is no longer needed. + */ +public interface RemovableChatMemoryProvider extends ChatMemoryProvider { + + void remove(Object id); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 108ddd68e..b343d3fb2 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -115,6 +115,7 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[] } Object memoryId = memoryId(createInfo, methodArgs).orElse("default"); + context.usedMemoryIds.add(memoryId); if (context.hasChatMemory()) { ChatMemory chatMemory = context.chatMemory(memoryId); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceBeanDestroyer.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceBeanDestroyer.java new file mode 100644 index 000000000..b6816edcc --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceBeanDestroyer.java @@ -0,0 +1,24 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import java.util.Map; + +import jakarta.enterprise.context.spi.CreationalContext; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.BeanDestroyer; + +public class DeclarativeAiServiceBeanDestroyer implements BeanDestroyer { + + private static final Logger log = Logger.getLogger(DeclarativeAiServiceBeanDestroyer.class); + + @Override + public void destroy(AutoCloseable instance, CreationalContext creationalContext, + Map params) { + try { + instance.close(); + } catch (Exception e) { + log.error("Unable to close " + instance); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java index 564d4e787..2d6cb8319 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java @@ -1,13 +1,44 @@ package io.quarkiverse.langchain4j.runtime.aiservice; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + import dev.langchain4j.service.AiServiceContext; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.RemovableChatMemoryProvider; import io.quarkiverse.langchain4j.audit.AuditService; public class QuarkusAiServiceContext extends AiServiceContext { public AuditService auditService; + public Set usedMemoryIds = ConcurrentHashMap.newKeySet(); + public QuarkusAiServiceContext(Class aiServiceClass) { super(aiServiceClass); } + + /** + * This is called by the {@code close} method of AiServices registered with {@link RegisterAiService} + * when the bean's scope is closed + */ + public void close() { + removeChatMemories(); + } + + private void removeChatMemories() { + if (usedMemoryIds.isEmpty()) { + return; + } + RemovableChatMemoryProvider removableChatMemoryProvider = null; + if (chatMemoryProvider instanceof RemovableChatMemoryProvider) { + removableChatMemoryProvider = (RemovableChatMemoryProvider) chatMemoryProvider; + } + for (Object memoryId : usedMemoryIds) { + if (removableChatMemoryProvider != null) { + removableChatMemoryProvider.remove(memoryId); + } + chatMemories.remove(memoryId); + } + } } diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java index d1facbfb6..b65f791af 100644 --- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java @@ -126,6 +126,7 @@ static class Calculator { } @RegisterAiService(tools = Calculator.class) + @Singleton interface Assistant { String chat(String message); diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java index f4586cffb..e8be0abd3 100644 --- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java @@ -127,6 +127,7 @@ public void deleteMessages(Object memoryId) { } @RegisterAiService + @Singleton interface ChatWithSeparateMemoryForEachUser { String chat(@MemoryId int memoryId, @UserMessage String userMessage); diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java index d5d2dd611..8e4a218d1 100644 --- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java @@ -19,6 +19,7 @@ import java.util.Map; import java.util.Optional; +import jakarta.enterprise.context.control.ActivateRequestContext; import jakarta.inject.Inject; import jakarta.inject.Singleton; @@ -105,6 +106,7 @@ interface Assistant { Assistant assistant; @Test + @ActivateRequestContext public void test_simple_instruction_with_single_argument_and_no_annotations() throws IOException { String result = assistant.chat("Tell me a joke about developers"); assertThat(result).isNotBlank(); @@ -129,6 +131,7 @@ interface SentimentAnalyzer { SentimentAnalyzer sentimentAnalyzer; @Test + @ActivateRequestContext void test_extract_enum() throws IOException { wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), "POSITIVE")); @@ -213,6 +216,7 @@ interface AssistantWithCalculator extends Assistant { AssistantWithCalculator assistantWithCalculator; @Test + @ActivateRequestContext void should_execute_tool_then_answer() throws IOException { var firstResponse = """ { @@ -308,6 +312,7 @@ interface ChatWithSeparateMemoryForEachUser { ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser; @Test + @ActivateRequestContext void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException { ChatMemoryStore store = Arc.container().instance(ChatMemoryStore.class).get(); diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java new file mode 100644 index 000000000..f586c8401 --- /dev/null +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java @@ -0,0 +1,258 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static dev.langchain4j.data.message.ChatMessageType.AI; +import static dev.langchain4j.data.message.ChatMessageType.USER; +import static org.acme.examples.aiservices.MessageAssertUtils.assertMultipleRequestMessage; +import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.tuple; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.RemovableChatMemoryProvider; +import io.quarkiverse.langchain4j.openai.test.WiremockUtils; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ManagedContext; +import io.quarkus.test.QuarkusUnitTest; + +public class RemovableChatMemoryTest { + + private static final int WIREMOCK_PORT = 8089; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub()); + } + + @ApplicationScoped + public static class ChatMemoryBean implements RemovableChatMemoryProvider { + + static final Map memories = new ConcurrentHashMap<>(); + + @Override + public ChatMemory get(Object memoryId) { + return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder() + .maxMessages(20) + .id(memoryId) + .build()); + } + + @Override + public void remove(Object id) { + memories.remove(id); + } + } + + @RegisterAiService + interface ChatWithSeparateMemoryForEachUser { + + String chat(@MemoryId int memoryId, @UserMessage String userMessage); + } + + @Inject + ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser; + + @Test + void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException { + + ManagedContext requestContext = Arc.container().requestContext(); + + // add a dummy entry that should not affect the chat in any way + ChatMemoryBean.memories.put("DUMMY", new ChatMemory() { + @Override + public Object id() { + return null; + } + + @Override + public void add(ChatMessage message) { + + } + + @Override + public List messages() { + return null; + } + + @Override + public void clear() { + + } + }); + + try { + requestContext.activate(); + testInRequestContext(); + } finally { + requestContext.terminate(); + } + + // since the request context was closed, we should now only have the initial dummy entry + assertThat(ChatMemoryBean.memories).containsOnlyKeys("DUMMY"); + } + + private void testInRequestContext() throws IOException { + int firstMemoryId = 1; + int secondMemoryId = 2; + + /* **** First request for user 1 **** */ + String firstMessageFromFirstUser = "Hello, my name is Klaus"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Nice to meet you Klaus")); + String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, firstMessageFromFirstUser); + + // assert response + assertThat(firstAiResponseToFirstUser).isEqualTo("Nice to meet you Klaus"); + + // assert request + assertSingleRequestMessage(getRequestAsMap(), firstMessageFromFirstUser); + + // assert chat memory + assertThat(ChatMemoryBean.memories.get(firstMemoryId).messages()).hasSize(2) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser)); + + /* **** First request for user 2 **** */ + wireMockServer.resetRequests(); + + String firstMessageFromSecondUser = "Hello, my name is Francine"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Nice to meet you Francine")); + String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId, firstMessageFromSecondUser); + + // assert response + assertThat(firstAiResponseToSecondUser).isEqualTo("Nice to meet you Francine"); + + // assert request + assertSingleRequestMessage(getRequestAsMap(), firstMessageFromSecondUser); + + // assert chat memory + assertThat(ChatMemoryBean.memories.get(secondMemoryId).messages()).hasSize(2) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser)); + + /* **** Second request for user 1 **** */ + wireMockServer.resetRequests(); + + String secondsMessageFromFirstUser = "What is my name?"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Your name is Klaus")); + String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, secondsMessageFromFirstUser); + + // assert response + assertThat(secondAiMessageToFirstUser).contains("Klaus"); + + // assert request + assertMultipleRequestMessage(getRequestAsMap(), + List.of( + new MessageAssertUtils.MessageContent("user", firstMessageFromFirstUser), + new MessageAssertUtils.MessageContent("assistant", firstAiResponseToFirstUser), + new MessageAssertUtils.MessageContent("user", secondsMessageFromFirstUser))); + + // assert chat memory + assertThat(ChatMemoryBean.memories.get(firstMemoryId).messages()).hasSize(4) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser), + tuple(USER, secondsMessageFromFirstUser), tuple(AI, secondAiMessageToFirstUser)); + + /* **** Second request for user 2 **** */ + wireMockServer.resetRequests(); + + String secondsMessageFromSecondUser = "What is my name?"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Your name is Francine")); + String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId, + secondsMessageFromSecondUser); + + // assert response + assertThat(secondAiMessageToSecondUser).contains("Francine"); + + // assert request + assertMultipleRequestMessage(getRequestAsMap(), + List.of( + new MessageAssertUtils.MessageContent("user", firstMessageFromSecondUser), + new MessageAssertUtils.MessageContent("assistant", firstAiResponseToSecondUser), + new MessageAssertUtils.MessageContent("user", secondsMessageFromSecondUser))); + + // assert chat memory + assertThat(ChatMemoryBean.memories.get(secondMemoryId).messages()).hasSize(4) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser), + tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser)); + } + + private Map getRequestAsMap() throws IOException { + return getRequestAsMap(getRequestBody()); + } + + private Map getRequestAsMap(byte[] body) throws IOException { + return mapper.readValue(body, MAP_TYPE_REF); + } + + private byte[] getRequestBody() { + assertThat(wireMockServer.getAllServeEvents()).hasSize(1); + ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this works because we reset requests for Wiremock before each test + return getRequestBody(serveEvent); + } + + private byte[] getRequestBody(ServeEvent serveEvent) { + LoggedRequest request = serveEvent.getRequest(); + assertThat(request.getBody()).isNotEmpty(); + return request.getBody(); + } +} diff --git a/samples/chatbot/pom.xml b/samples/chatbot/pom.xml index 3bbdacea1..6849a3bdc 100644 --- a/samples/chatbot/pom.xml +++ b/samples/chatbot/pom.xml @@ -21,6 +21,10 @@ io.quarkus quarkus-websockets + + io.quarkus + quarkus-smallrye-context-propagation + io.quarkiverse.langchain4j quarkus-langchain4j-openai @@ -146,4 +150,4 @@ - \ No newline at end of file + diff --git a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java index d2a58a7d2..d692238e5 100644 --- a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java +++ b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java @@ -6,7 +6,7 @@ import jakarta.websocket.*; import jakarta.websocket.server.ServerEndpoint; -import io.smallrye.mutiny.infrastructure.Infrastructure; +import org.eclipse.microprofile.context.ManagedExecutor; @ServerEndpoint("/chatbot") public class ChatBotWebSocket { @@ -14,12 +14,15 @@ public class ChatBotWebSocket { @Inject Bot bot; + @Inject + ManagedExecutor managedExecutor; + @Inject ChatMemoryBean chatMemoryBean; @OnOpen public void onOpen(Session session) { - Infrastructure.getDefaultExecutor().execute(() -> { + managedExecutor.execute(() -> { String response = bot.chat(session, "hello"); try { session.getBasicRemote().sendText(response); @@ -36,7 +39,7 @@ void onClose(Session session) { @OnMessage public void onMessage(String message, Session session) { - Infrastructure.getDefaultExecutor().execute(() -> { + managedExecutor.execute(() -> { String response = bot.chat(session, message); try { session.getBasicRemote().sendText(response); diff --git a/samples/csv-chatbot/pom.xml b/samples/csv-chatbot/pom.xml index c086a2a33..c3a5fe56e 100644 --- a/samples/csv-chatbot/pom.xml +++ b/samples/csv-chatbot/pom.xml @@ -21,6 +21,10 @@ io.quarkus quarkus-websockets + + io.quarkus + quarkus-smallrye-context-propagation + io.quarkiverse.langchain4j quarkus-langchain4j-openai @@ -155,4 +159,4 @@ - \ No newline at end of file + diff --git a/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java b/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java index 2a98274a3..0f5d6268a 100644 --- a/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java +++ b/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java @@ -3,10 +3,13 @@ import java.io.IOException; import jakarta.inject.Inject; -import jakarta.websocket.*; +import jakarta.websocket.OnClose; +import jakarta.websocket.OnMessage; +import jakarta.websocket.OnOpen; +import jakarta.websocket.Session; import jakarta.websocket.server.ServerEndpoint; -import io.smallrye.mutiny.infrastructure.Infrastructure; +import org.eclipse.microprofile.context.ManagedExecutor; @ServerEndpoint("/chatbot") public class ChatBotWebSocket { @@ -14,12 +17,15 @@ public class ChatBotWebSocket { @Inject MovieMuse bot; + @Inject + ManagedExecutor managedExecutor; + @Inject ChatMemoryBean chatMemoryBean; @OnOpen public void onOpen(Session session) { - Infrastructure.getDefaultExecutor().execute(() -> { + managedExecutor.execute(() -> { String response = bot.chat(session, "hello"); try { session.getBasicRemote().sendText(response); @@ -36,7 +42,7 @@ void onClose(Session session) { @OnMessage public void onMessage(String message, Session session) { - Infrastructure.getDefaultExecutor().execute(() -> { + managedExecutor.execute(() -> { String response = bot.chat(session, message); try { session.getBasicRemote().sendText(response); diff --git a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java index 5f28a29f2..4ab1056c0 100644 --- a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java +++ b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java @@ -3,15 +3,14 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import jakarta.annotation.PreDestroy; -import jakarta.enterprise.context.RequestScoped; +import jakarta.inject.Singleton; import dev.langchain4j.memory.ChatMemory; -import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import io.quarkiverse.langchain4j.RemovableChatMemoryProvider; -@RequestScoped -public class ChatMemoryBean implements ChatMemoryProvider { +@Singleton +public class ChatMemoryBean implements RemovableChatMemoryProvider { private final Map memories = new ConcurrentHashMap<>(); @@ -23,8 +22,8 @@ public ChatMemory get(Object memoryId) { .build()); } - @PreDestroy - public void close() { - memories.clear(); + @Override + public void remove(Object id) { + memories.remove(id); } }