diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
index 2e35fa199..b6ea54884 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
@@ -49,6 +49,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
+import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
@@ -106,6 +107,9 @@ public class AiServicesProcessor {
private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_CLOSE = MethodDescriptor.ofMethod(
QuarkusAiServiceContext.class, "close", void.class);
+
+ private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_REMOVE_CHAT_MEMORY_IDS = MethodDescriptor.ofMethod(
+ QuarkusAiServiceContext.class, "removeChatMemoryIds", void.class, Object[].class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);
@BuildStep
@@ -184,16 +188,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
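- // the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean
+ // by default a ChatMemoryProvider bean is always required, regardless of whether tools are declared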
- DotName chatMemoryProviderSupplierClassDotName = toolDotNames.isEmpty()
- ? Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER
- : Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
+ DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
if (chatMemoryProviderSupplierValue != null) {
chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
- if (chatMemoryProviderSupplierClassDotName.equals(
- Langchain4jDotNames.NO_CHAT_MEMORY_PROVIDER_SUPPLIER)) {
- chatMemoryProviderSupplierClassDotName = null;
- } else if (!chatMemoryProviderSupplierClassDotName
+ if (!chatMemoryProviderSupplierClassDotName
.equals(Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER)) {
validateSupplierAndRegisterForReflection(chatMemoryProviderSupplierClassDotName, index,
reflectiveClassProducer);
@@ -313,11 +312,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER));
needsChatMemoryProviderBean = true;
- } else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
- .equals(chatMemoryProviderSupplierClassName)) {
- configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
- new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
- needsChatMemoryProviderBean = true;
}
if (Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER.toString().equals(retrieverSupplierClassName)) {
@@ -472,7 +466,7 @@ public void handleAiServices(AiServicesRecorder recorder,
ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(classOutput)
.className(implClassName)
- .interfaces(ifaceName);
+ .interfaces(ifaceName, ChatMemoryRemovable.class.getName());
if (isRegisteredService) {
classCreatorBuilder.interfaces(AutoCloseable.class);
}
@@ -524,6 +518,16 @@ public void handleAiServices(AiServicesRecorder recorder,
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
mc.returnVoid();
}
+
+ {
+ MethodCreator mc = classCreator.getMethodCreator(
+ MethodDescriptor.ofMethod(implClassName, "remove", void.class, Object[].class));
+ ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
+ mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_REMOVE_CHAT_MEMORY_IDS, contextHandle,
+ mc.getMethodParam(0));
+ mc.returnVoid();
+ }
+
}
perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
// make the constructor accessible reflectively since that is how we create the instance
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ChatMemoryBuildConfig.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ChatMemoryBuildConfig.java
new file mode 100644
index 000000000..05cd6220d
--- /dev/null
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ChatMemoryBuildConfig.java
@@ -0,0 +1,40 @@
+package io.quarkiverse.langchain4j.deployment;
+
+import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;
+
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.Tokenizer;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkus.runtime.annotations.ConfigRoot;
+import io.smallrye.config.ConfigMapping;
+import io.smallrye.config.WithDefault;
+
+/**
+ * Build-time configuration mapped to the {@code quarkus.langchain4j.chat-memory} prefix.
+ */
+@ConfigRoot(phase = BUILD_TIME)
+@ConfigMapping(prefix = "quarkus.langchain4j.chat-memory")
+public interface ChatMemoryBuildConfig {
+
+ /**
+ * Configure the type of {@link ChatMemory} that will be used by default by the default {@link ChatMemoryProvider} bean.
+ *
+ * The extension provides a default bean that configures {@link ChatMemoryProvider} for use with AI services
+ * registered with {@link RegisterAiService}. This bean uses the {@code quarkus.langchain4j.chat-memory}
+ * configuration to set things up while also depending on the presence of a bean of type {@link ChatMemoryStore} (for which
+ * the extension also provides a default in the form of {@link InMemoryChatMemoryStore}).
+ *
+ * If {@code token-window} is used, then the application must also provide a bean of type {@link Tokenizer}.
+ *
+ * Users can choose to provide their own {@link ChatMemoryStore} bean or even their own {@link ChatMemoryProvider} bean
+ * if full control over the details is needed.
+ */
+ @WithDefault("MESSAGE_WINDOW")
+ Type type();
+
+ /**
+ * The kinds of {@link ChatMemory} the default {@link ChatMemoryProvider} bean can be set up with.
+ */
+ enum Type {
+ MESSAGE_WINDOW,
+ TOKEN_WINDOW
+ }
+
+}
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ChatMemoryProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ChatMemoryProcessor.java
new file mode 100644
index 000000000..05d56a82d
--- /dev/null
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ChatMemoryProcessor.java
@@ -0,0 +1,51 @@
+package io.quarkiverse.langchain4j.deployment;
+
+import java.util.function.Function;
+
+import jakarta.enterprise.context.ApplicationScoped;
+
+import org.jboss.jandex.ClassType;
+
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.Tokenizer;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+import io.quarkiverse.langchain4j.runtime.ChatMemoryRecorder;
+import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryConfig;
+import io.quarkus.arc.SyntheticCreationalContext;
+import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
+import io.quarkus.deployment.annotations.BuildProducer;
+import io.quarkus.deployment.annotations.BuildStep;
+import io.quarkus.deployment.annotations.ExecutionTime;
+import io.quarkus.deployment.annotations.Record;
+
+/**
+ * Build steps that register the default {@link ChatMemoryProvider} synthetic bean.
+ */
+public class ChatMemoryProcessor {
+
+ /**
+ * Produces an {@link ApplicationScoped}, runtime-initialized, default {@link ChatMemoryProvider} bean whose
+ * creation function is selected from the build-time {@code quarkus.langchain4j.chat-memory.type} value.
+ *
+ * @param buildConfig build-time configuration; only {@link ChatMemoryBuildConfig#type()} is read here
+ * @param runtimeConfig runtime configuration that is handed to the recorder
+ * @param recorder supplies the creation function for the selected memory type
+ * @param syntheticBeanProducer receives the configured synthetic bean
+ * @throws IllegalStateException if the configured type is neither {@code MESSAGE_WINDOW} nor {@code TOKEN_WINDOW}
+ */
+ @BuildStep
+ @Record(ExecutionTime.RUNTIME_INIT)
+ void setupBeans(ChatMemoryBuildConfig buildConfig, ChatMemoryConfig runtimeConfig,
+ ChatMemoryRecorder recorder,
+ BuildProducer syntheticBeanProducer) {
+
+ Function, ChatMemoryProvider> fun;
+
+ // every variant needs a ChatMemoryStore; the bean is registered as a default bean so that an
+ // application-provided ChatMemoryProvider can take its place
+ SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
+ .configure(ChatMemoryProvider.class)
+ .setRuntimeInit()
+ .addInjectionPoint(ClassType.create(ChatMemoryStore.class))
+ .scope(ApplicationScoped.class)
+ .defaultBean();
+
+ if (buildConfig.type() == ChatMemoryBuildConfig.Type.MESSAGE_WINDOW) {
+ fun = recorder.messageWindow(runtimeConfig);
+ } else if (buildConfig.type() == ChatMemoryBuildConfig.Type.TOKEN_WINDOW) {
+ // the token window creation function also looks up a Tokenizer, so declare it as an injection point
+ configurator.addInjectionPoint(ClassType.create(Tokenizer.class));
+ fun = recorder.tokenWindow(runtimeConfig);
+ } else {
+ throw new IllegalStateException(
+ "Invalid configuration '" + buildConfig.type() + "' used in 'quarkus.langchain4j.chat-memory.type'");
+ }
+ configurator.createWith(fun);
+
+ syntheticBeanProducer.produce(configurator.done());
+ }
+}
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java
index 428e5c447..a00aee0c0 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java
@@ -46,10 +46,6 @@ public class Langchain4jDotNames {
static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatMemoryProviderSupplier.class);
- static final DotName BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
- RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class);
- static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
- RegisterAiService.NoChatMemoryProviderSupplier.class);
static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class);
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/ChatMemoryRemover.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/ChatMemoryRemover.java
new file mode 100644
index 000000000..6edf201d3
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/ChatMemoryRemover.java
@@ -0,0 +1,39 @@
+package io.quarkiverse.langchain4j;
+
+import java.util.List;
+
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;
+
+/**
+ * Allows the application to manually control when a {@link ChatMemory} should be removed from the underlying
+ * {@link ChatMemoryStore}.
+ */
+public final class ChatMemoryRemover {
+
+ // zero-length array handed to List#toArray to obtain an Object[]
+ private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0];
+
+ // static utility holder - not meant to be instantiated
+ private ChatMemoryRemover() {
+ }
+
+ /**
+ * Removes the {@link ChatMemory} of a single memory ID. The call is silently ignored when {@code aiService}
+ * does not implement {@link ChatMemoryRemovable}.
+ *
+ * @param aiService The bean that implements the AI Service annotated with {@link RegisterAiService}
+ * @param memoryId The object used as memory ID for which the corresponding {@link ChatMemory} should be removed
+ */
+ public static void remove(Object aiService, Object memoryId) {
+ if (aiService instanceof ChatMemoryRemovable r) {
+ r.remove(memoryId);
+ }
+ }
+
+ /**
+ * Removes the {@link ChatMemory} of each of the given memory IDs. The call is silently ignored when
+ * {@code aiService} does not implement {@link ChatMemoryRemovable}.
+ *
+ * @param aiService The bean that implements the AI Service annotated with {@link RegisterAiService}
+ * @param memoryIds The objects used as memory IDs for which the corresponding {@link ChatMemory} should be removed
+ */
+ public static void remove(Object aiService, List memoryIds) {
+ if (aiService instanceof ChatMemoryRemovable r) {
+ r.remove(memoryIds.toArray(EMPTY_OBJECT_ARRAY));
+ }
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
index c283212d2..ef916d670 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
@@ -10,9 +10,12 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import io.quarkiverse.langchain4j.audit.AuditService;
/**
@@ -44,27 +47,29 @@
/**
* Tool classes to use. All tools are expected to be CDI beans.
- *
- * NOTE: when this is used, either a {@link ChatMemoryProvider} bean must be present in the application, or a custom
- * {@link Supplier} must be set.
*/
Class>[] tools() default {};
/**
- * Configures the way to obtain the {@link ChatMemoryProvider} to use.
- * By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}, but will fall back to not using
- * any memory if no such bean exists.
- * If an arbitrary {@link ChatMemoryProvider} instance is needed, a custom implementation of
- * {@link Supplier} needs to be provided.
+ * Configures the way to obtain the {@link ChatMemoryProvider}.
+ *
+ * By default, Quarkus configures a {@link ChatMemoryProvider} bean that uses a {@link InMemoryChatMemoryStore} bean
+ * as the backing store. The default type for the actual {@link ChatMemory} is {@link MessageWindowChatMemory}
+ * and it is configured with the value of the {@code quarkus.langchain4j.chat-memory.memory-window.max-messages}
+ * configuration property (which defaults to 10) as a way of limiting the number of messages in each chat.
*
- * If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then
- * set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier}
+ * If the application provides its own {@link ChatMemoryProvider} bean, that takes precedence over what Quarkus provides as
+ * the default.
+ *
+ * If the application provides an implementation of {@link ChatMemoryStore}, then that is used instead of the default
+ * {@link InMemoryChatMemoryStore}.
+ *
+ * In the most advanced case, an arbitrary {@link ChatMemoryProvider} can be used: a custom
+ * {@link Supplier} implementation that returns the desired provider needs to be written and
+ * configured in this property.
*
- * NOTE: when {@link tools} is set, the default is changed to {@link BeanChatMemoryProviderSupplier} which means that a
- * bean a {@link ChatMemoryProvider} bean must be present. The alternative in this case is to set a custom
- * {@link Supplier}.
*/
- Class extends Supplier> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class;
+ Class extends Supplier> chatMemoryProviderSupplier() default BeanChatMemoryProviderSupplier.class;
/**
* Configures the way to obtain the {@link Retriever} to use (when using RAG).
@@ -98,8 +103,11 @@ public ChatLanguageModel get() {
}
/**
- * Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. If the bean does
- * not exist, Quarkus will fail at build time.
+ * Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that is available as a CDI bean.
+ * By default, Quarkus configures a {@link ChatMemoryProvider} by using an {@link InMemoryChatMemoryStore}
+ * as the backing store while using {@link MessageWindowChatMemory} with the value of
+ * configuration property {@code quarkus.langchain4j.chat-memory.memory-window.max-messages} (which defaults to 10)
+ * as a way of limiting the number of messages in each chat.
*/
final class BeanChatMemoryProviderSupplier implements Supplier {
@@ -109,29 +117,6 @@ public ChatMemoryProvider get() {
}
}
- /**
- * Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that the user has configured as a CDI bean.
- * If no such bean exists, then no memory will be used.
- */
- final class BeanIfExistsChatMemoryProviderSupplier implements Supplier {
-
- @Override
- public ChatMemoryProvider get() {
- throw new UnsupportedOperationException("should never be called");
- }
- }
-
- /**
- * Marker class to indicate that no chat memory should be used
- */
- final class NoChatMemoryProviderSupplier implements Supplier {
-
- @Override
- public ChatMemoryProvider get() {
- throw new UnsupportedOperationException("should never be called");
- }
- }
-
/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean
*/
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java
deleted file mode 100644
index e309f27d8..000000000
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RemovableChatMemoryProvider.java
+++ /dev/null
@@ -1,13 +0,0 @@
-package io.quarkiverse.langchain4j;
-
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.ChatMemoryProvider;
-
-/**
- * Extends {@link ChatMemoryProvider} to allow for removing {@link ChatMemory}
- * when it is no longer needed.
- */
-public interface RemovableChatMemoryProvider extends ChatMemoryProvider {
-
- void remove(Object id);
-}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
index 05f6cc333..59d17f148 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
@@ -104,13 +104,6 @@ public T apply(SyntheticCreationalContext creationalContext) {
.equals(info.getChatMemoryProviderSupplierClassName())) {
quarkusAiServices.chatMemoryProvider(creationalContext.getInjectedReference(
ChatMemoryProvider.class));
- } else if (RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class.getName()
- .equals(info.getChatMemoryProviderSupplierClassName())) {
- Instance instance = creationalContext
- .getInjectedReference(CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL);
- if (instance.isResolvable()) {
- quarkusAiServices.chatMemoryProvider(instance.get());
- }
} else {
Supplier extends ChatMemoryProvider> supplier = (Supplier extends ChatMemoryProvider>) Thread
.currentThread().getContextClassLoader()
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ChatMemoryRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ChatMemoryRecorder.java
new file mode 100644
index 000000000..555146a1c
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ChatMemoryRecorder.java
@@ -0,0 +1,58 @@
+package io.quarkiverse.langchain4j.runtime;
+
+import java.util.function.Function;
+
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.memory.chat.TokenWindowChatMemory;
+import dev.langchain4j.model.Tokenizer;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryConfig;
+import io.quarkus.arc.SyntheticCreationalContext;
+import io.quarkus.runtime.annotations.Recorder;
+
+/**
+ * Recorder that creates the functions used to instantiate the default {@link ChatMemoryProvider} bean.
+ */
+@Recorder
+public class ChatMemoryRecorder {
+
+ /**
+ * Returns a creation function for a {@link ChatMemoryProvider} that builds a {@link MessageWindowChatMemory}
+ * for every requested memory ID, backed by the injected {@link ChatMemoryStore}.
+ *
+ * @param config runtime configuration; {@code memoryWindow().maxMessages()} is read when the bean is created
+ * @return the function that creates the provider from the bean's creational context
+ */
+ public Function, ChatMemoryProvider> messageWindow(ChatMemoryConfig config) {
+ return new Function<>() {
+ @Override
+ public ChatMemoryProvider apply(SyntheticCreationalContext context) {
+ ChatMemoryStore chatMemoryStore = context.getInjectedReference(ChatMemoryStore.class);
+ // resolved once per bean creation and shared by every memory the provider hands out
+ int maxMessages = config.memoryWindow().maxMessages();
+ return new ChatMemoryProvider() {
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return MessageWindowChatMemory.builder()
+ .maxMessages(maxMessages)
+ .id(memoryId)
+ .chatMemoryStore(chatMemoryStore)
+ .build();
+ }
+ };
+ }
+ };
+ }
+
+ /**
+ * Returns a creation function for a {@link ChatMemoryProvider} that builds a {@link TokenWindowChatMemory}
+ * for every requested memory ID, backed by the injected {@link ChatMemoryStore} and {@link Tokenizer}.
+ *
+ * @param config runtime configuration; {@code tokenWindow().maxTokens()} is read when the bean is created
+ * @return the function that creates the provider from the bean's creational context
+ */
+ public Function, ChatMemoryProvider> tokenWindow(ChatMemoryConfig config) {
+ return new Function<>() {
+ @Override
+ public ChatMemoryProvider apply(SyntheticCreationalContext context) {
+ ChatMemoryStore chatMemoryStore = context.getInjectedReference(ChatMemoryStore.class);
+ Tokenizer tokenizer = context.getInjectedReference(Tokenizer.class);
+ // resolved once per bean creation and shared by every memory the provider hands out
+ int maxTokens = config.tokenWindow().maxTokens();
+ return new ChatMemoryProvider() {
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return TokenWindowChatMemory.builder()
+ .maxTokens(maxTokens, tokenizer)
+ .id(memoryId)
+ .chatMemoryStore(chatMemoryStore)
+ .build();
+ }
+ };
+ }
+ };
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
index b343d3fb2..108ddd68e 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
@@ -115,7 +115,6 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[]
}
Object memoryId = memoryId(createInfo, methodArgs).orElse("default");
- context.usedMemoryIds.add(memoryId);
if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ChatMemoryConfig.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ChatMemoryConfig.java
new file mode 100644
index 000000000..64b80eda1
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ChatMemoryConfig.java
@@ -0,0 +1,54 @@
+package io.quarkiverse.langchain4j.runtime.aiservice;
+
+import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME;
+
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.memory.chat.TokenWindowChatMemory;
+import io.quarkus.runtime.annotations.ConfigGroup;
+import io.quarkus.runtime.annotations.ConfigRoot;
+import io.smallrye.config.ConfigMapping;
+import io.smallrye.config.WithDefault;
+
+/**
+ * Runtime configuration mapped to the {@code quarkus.langchain4j.chat-memory} prefix.
+ */
+@ConfigRoot(phase = RUN_TIME)
+@ConfigMapping(prefix = "quarkus.langchain4j.chat-memory")
+public interface ChatMemoryConfig {
+
+ /**
+ * Configures aspects of the {@link MessageWindowChatMemory} which is the default {@link dev.langchain4j.memory.ChatMemory}
+ * setup by the extension.
+ * This only has effect if {@code quarkus.langchain4j.chat-memory.type} has not been configured (or is configured to
+ * {@code message-window}) and no bean of
+ * type {@link ChatMemoryProvider}
+ * is present in the application.
+ */
+ MemoryWindow memoryWindow();
+
+ /**
+ * Configures aspects of the {@link TokenWindowChatMemory} which is enabled if the
+ * {@code quarkus.langchain4j.chat-memory.type} configuration property
+ * is set to {@code token-window} and if no bean of type {@link ChatMemoryProvider} is present in the application.
+ */
+ TokenWindow tokenWindow();
+
+ @ConfigGroup
+ interface MemoryWindow {
+
+ /**
+ * The maximum number of messages the configured {@link MessageWindowChatMemory} will hold
+ */
+ @WithDefault("10")
+ int maxMessages();
+ }
+
+ @ConfigGroup
+ interface TokenWindow {
+
+ /**
+ * The maximum number of tokens the configured {@link TokenWindowChatMemory} will hold
+ */
+ @WithDefault("1000")
+ int maxTokens();
+ }
+
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ChatMemoryRemovable.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ChatMemoryRemovable.java
new file mode 100644
index 000000000..d398db0bd
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/ChatMemoryRemovable.java
@@ -0,0 +1,9 @@
+package io.quarkiverse.langchain4j.runtime.aiservice;
+
+/**
+ * Interface implemented by each AiService that allows the removal of chat memories from an AiService
+ */
+public interface ChatMemoryRemovable {
+
+ /**
+ * Removes the chat memories associated with the given memory IDs.
+ *
+ * @param ids the memory IDs whose chat memories should be removed
+ */
+ void remove(Object... ids);
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/InMemoryChatMemoryStoreProducer.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/InMemoryChatMemoryStoreProducer.java
new file mode 100644
index 000000000..500b31a4d
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/InMemoryChatMemoryStoreProducer.java
@@ -0,0 +1,23 @@
+package io.quarkiverse.langchain4j.runtime.aiservice;
+
+import jakarta.enterprise.inject.Produces;
+import jakarta.inject.Singleton;
+
+import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkus.arc.DefaultBean;
+import io.quarkus.arc.Unremovable;
+
+/**
+ * Creates the default {@link InMemoryChatMemoryStore} store to be used by classes annotated with {@link RegisterAiService}
+ */
+public class InMemoryChatMemoryStoreProducer {
+
+ /**
+ * Produces the {@link Singleton} in-memory store. It is declared as a {@link DefaultBean}, so it is meant to be
+ * used only when the application does not provide its own store bean.
+ *
+ * @return a new, empty {@link InMemoryChatMemoryStore}
+ */
+ @Produces
+ @Singleton
+ @DefaultBean
+ @Unremovable
+ public InMemoryChatMemoryStore chatMemoryStore() {
+ return new InMemoryChatMemoryStore();
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java
index 2d6cb8319..82c85ece5 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java
@@ -1,19 +1,16 @@
package io.quarkiverse.langchain4j.runtime.aiservice;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiConsumer;
+import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.service.AiServiceContext;
import io.quarkiverse.langchain4j.RegisterAiService;
-import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
import io.quarkiverse.langchain4j.audit.AuditService;
public class QuarkusAiServiceContext extends AiServiceContext {
public AuditService auditService;
- public Set usedMemoryIds = ConcurrentHashMap.newKeySet();
-
public QuarkusAiServiceContext(Class> aiServiceClass) {
super(aiServiceClass);
}
@@ -23,22 +20,30 @@ public QuarkusAiServiceContext(Class> aiServiceClass) {
* when the bean's scope is closed
*/
public void close() {
- removeChatMemories();
+ clearChatMemory();
}
- private void removeChatMemories() {
- if (usedMemoryIds.isEmpty()) {
- return;
- }
- RemovableChatMemoryProvider removableChatMemoryProvider = null;
- if (chatMemoryProvider instanceof RemovableChatMemoryProvider) {
- removableChatMemoryProvider = (RemovableChatMemoryProvider) chatMemoryProvider;
- }
- for (Object memoryId : usedMemoryIds) {
- if (removableChatMemoryProvider != null) {
- removableChatMemoryProvider.remove(memoryId);
+ /**
+ * Clears every cached {@link ChatMemory} and then drops the cache itself.
+ * Does nothing when there is no cache, so calling {@link #close()} more than once is safe.
+ */
+ private void clearChatMemory() {
+ if (chatMemories == null) {
+ // nothing cached: either no chat memory is in use or a previous close() already released it
+ return;
+ }
+ chatMemories.forEach(new BiConsumer<>() {
+ @Override
+ public void accept(Object memoryId, ChatMemory chatMemory) {
+ chatMemory.clear();
+ }
+ });
+ chatMemories = null;
+ }
+
+ /**
+ * This is called by the {@code remove(Object... ids)} method of AiServices when a user manually requests removal of chat
+ * memories
+ * via {@link io.quarkiverse.langchain4j.ChatMemoryRemover}.
+ * Each memory that is found is removed from the cache and cleared; unknown IDs are ignored.
+ * Does nothing when there is no cache, for example after {@link #close()} has been called.
+ *
+ * @param ids the memory IDs whose chat memories should be removed
+ */
+ public void removeChatMemoryIds(Object... ids) {
+ if (chatMemories == null) {
+ // close() sets the cache to null, so there is nothing left to remove
+ return;
+ }
+ for (Object id : ids) {
+ ChatMemory chatMemory = chatMemories.remove(id);
+ if (chatMemory != null) {
+ chatMemory.clear();
}
- chatMemories.remove(memoryId);
}
}
}
diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java
deleted file mode 100644
index 8eb87ae96..000000000
--- a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java
+++ /dev/null
@@ -1,29 +0,0 @@
-package io.quarkiverse.langchain4j.samples;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-import jakarta.inject.Singleton;
-
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
-import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
-
-@Singleton
-public class ChatMemoryBean implements RemovableChatMemoryProvider {
-
- private final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(20)
- .id(memoryId)
- .build());
- }
-
- @Override
- public void remove(Object id) {
- memories.remove(id);
- }
-}
diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/CustomChatMemoryProvider.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/CustomChatMemoryProvider.java
new file mode 100644
index 000000000..b6998b80f
--- /dev/null
+++ b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/CustomChatMemoryProvider.java
@@ -0,0 +1,32 @@
+package io.quarkiverse.langchain4j.samples;
+
+import jakarta.inject.Singleton;
+
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+
+/**
+ * Sample skeleton of an application-provided {@link ChatMemoryProvider} bean.
+ * Both factory methods are placeholders that currently return {@code null}.
+ */
+@Singleton
+public class CustomChatMemoryProvider implements ChatMemoryProvider {
+
+ private final ChatMemoryStore store;
+
+ public CustomChatMemoryProvider() {
+ this.store = createCustomStore();
+ }
+
+ private static ChatMemoryStore createCustomStore() {
+ // TODO: provide some kind of custom store
+ return null;
+ }
+
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return createCustomMemory(memoryId);
+ }
+
+ // NOTE(review): this helper is static and only receives memoryId, so it cannot reach the 'store' field
+ // mentioned in the TODO below - confirm whether it should be an instance method or take the store as argument
+ private static ChatMemory createCustomMemory(Object memoryId) {
+ // TODO: implement using memoryId and store
+ return null;
+ }
+}
diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/CustomProvider.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/CustomProvider.java
new file mode 100644
index 000000000..1543b2b49
--- /dev/null
+++ b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/CustomProvider.java
@@ -0,0 +1,28 @@
+package io.quarkiverse.langchain4j.samples;
+
+import java.util.function.Supplier;
+
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
+
+/**
+ * Sample {@link Supplier} of a {@link ChatMemoryProvider}. Every {@link ChatMemory} created by the returned
+ * provider is a {@link MessageWindowChatMemory} limited to 20 messages and backed by the single
+ * {@link InMemoryChatMemoryStore} held by this supplier.
+ */
+public class CustomProvider implements Supplier {
+
+ // shared by all memories created through this supplier
+ private final InMemoryChatMemoryStore store = new InMemoryChatMemoryStore();
+
+ @Override
+ public ChatMemoryProvider get() {
+ return new ChatMemoryProvider() {
+
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return MessageWindowChatMemory.builder()
+ .maxMessages(20)
+ .id(memoryId)
+ .chatMemoryStore(store)
+ .build();
+ }
+ };
+ }
+}
diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MySmallMemoryProvider.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MySmallMemoryProvider.java
deleted file mode 100644
index a4afa028d..000000000
--- a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MySmallMemoryProvider.java
+++ /dev/null
@@ -1,32 +0,0 @@
-package io.quarkiverse.langchain4j.samples;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Supplier;
-
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.ChatMemoryProvider;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
-import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
-
-public class MySmallMemoryProvider implements Supplier {
- @Override
- public ChatMemoryProvider get() {
- return new RemovableChatMemoryProvider() {
- private final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(20)
- .id(memoryId)
- .build());
- }
-
- @Override
- public void remove(Object id) {
- memories.remove(id);
- }
- };
- }
-}
diff --git a/docs/modules/ROOT/pages/agent-and-tools.adoc b/docs/modules/ROOT/pages/agent-and-tools.adoc
index 6d5b4f99c..260e3abdb 100644
--- a/docs/modules/ROOT/pages/agent-and-tools.adoc
+++ b/docs/modules/ROOT/pages/agent-and-tools.adoc
@@ -124,7 +124,7 @@ public class AssistantWithToolsResource {
this.assistant = assistant;
}
- @GET // <6>
+ @GET // <3>
public String get(@RestQuery String message) {
return assistant.chat(message);
}
@@ -148,25 +148,7 @@ public class AssistantWithToolsResource {
}
}
- @Singleton // <2>
- public static class ChatMemoryBean implements RemovableChatMemoryProvider { // <3>
-
- private final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(20)
- .id(memoryId)
- .build());
- }
-
- public void remove(Object id) { // <4>
- memories.remove(id);
- }
- }
-
- @RegisterAiService(tools = Calculator.class) // <5>
+ @RegisterAiService(tools = Calculator.class) // <2>
public interface Assistant {
String chat(String userMessage);
@@ -174,11 +156,8 @@ public class AssistantWithToolsResource {
}
----
<1> Declare a CDI bean that provides three different tools
-<2> Declare a CDI bean for providing a simple in-memory message store
-<3> By making the bean implement `RemovableChatMemoryProvider`, the objects used as memory IDs are removed from memory when the service goes out of scope
-<4> The `remove` method is called automatically by Quarkus when an AiService goes out of scope in order to remove the memory objects used by said service
-<5> Register an AiService that responds to a user's request and has access to the calculator tools. This service is also able to keep track of the session's messages using the CDI message store declared above.
-<6> Declare an HTTP endpoint that retrieves the user's question via a query parameter and simply responds with chatbot's response
+<2> Register an AiService that responds to a user's request and has access to the calculator tools. This service is also able to keep track of the session's messages using the chat memory that the extension provides by default.
+<3> Declare an HTTP endpoint that retrieves the user's question via a query parameter and simply responds with the chatbot's response
Now, if we ask the chatbot `What is the square root of the sum of the numbers of letters in the words "hello" and "world"` via:
diff --git a/docs/modules/ROOT/pages/ai-services.adoc b/docs/modules/ROOT/pages/ai-services.adoc
index 4cca74256..849877d88 100644
--- a/docs/modules/ROOT/pages/ai-services.adoc
+++ b/docs/modules/ROOT/pages/ai-services.adoc
@@ -27,10 +27,11 @@ Once registered, you can inject the _AI Service_ into your application:
@Inject MyAiService service;
----
+[#scope]
[IMPORTANT]
====
The beans created by `@RegisterAiService` are `@RequestScoped` by default. The reason for this is that it enables removing chat <<memory>> objects.
-This is a good default when a service is used during when handling an HTTP request, but it's inappropriate in CLIs or in WebSockets (currently, but may change in the future).
+This is a good default when a service is used while handling an HTTP request, but it's inappropriate in CLIs or in WebSockets (WebSocket support is expected to improve in the near future).
For example when using a service in a CLI, it makes sense to have the service be `@ApplicationScoped` and the extension allows this simply if the service is annotated with `@ApplicationScoped`.
====
@@ -175,24 +176,40 @@ public class MyChatModelSupplier implements Supplier {
As LLMs are stateless, the memory — comprising the interaction context — must be exchanged each time. To prevent storing excessive messages, it's crucial to evict older messages.
-The `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the memory provider. The default value of this annotation is `RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class`
+The `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the `dev.langchain4j.memory.chat.ChatMemoryProvider`. The default value of this attribute is `RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class`
which means that the `AiService` will use whatever `ChatMemoryProvider` bean is configured by the application, while falling back to no memory if no such bean exists.
-An example of such a bean is:
+
+The extension provides a default implementation of `ChatMemoryProvider` which does two things:
+
+* It uses whatever `dev.langchain4j.store.memory.chat.ChatMemoryStore` bean is configured as the backing store. The default implementation is `dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore`.
+** If the application provides its own `ChatMemoryStore` bean, that will be used instead of the default `InMemoryChatMemoryStore`.
+* It leverages the available configuration options under `quarkus.langchain4j.chat-memory` to construct the `ChatMemoryProvider`.
+** The default configuration values result in the usage of `dev.langchain4j.memory.chat.MessageWindowChatMemory` with a window size of ten
+** By setting `quarkus.langchain4j.chat-memory.type=token-window`, a `dev.langchain4j.memory.chat.TokenWindowChatMemory` will be used. Note that this requires the presence of a `dev.langchain4j.model.Tokenizer` bean.
+
+[IMPORTANT]
+====
+The topic of `ChatMemory` cleanup is of paramount importance in order to avoid having the application terminate with `out of memory` errors. For this reason, the extension automatically removes all the `ChatMemory` objects from
+the underlying `ChatMemoryStore` when the AI Service goes out of scope (recall from our discussion about <<scope>> that such beans are `@RequestScoped` by default).
+
+However, in cases where more fine-grained control is needed (which is the case when the bean is declared as `@Singleton` or `@ApplicationScoped`) then `io.quarkiverse.langchain4j.ChatMemoryRemover` should be used to manually remove elements.
+====
+
+=== Advanced usage
+
+Although the extension's default `ChatMemoryProvider` is very configurable, making it unnecessary in most cases to resort to a custom implementation, providing one is still possible. Here is a possible example:
[source,java]
----
-include::{examples-dir}/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java[]
+include::{examples-dir}/io/quarkiverse/langchain4j/samples/CustomChatMemoryProvider.java[]
----
-Notice that the messages are deleted when the scope terminates (as it will call the `close` method).
-
-NOTE: It is recommended to have your chat memory beans implement `RemovableChatMemoryProvider` because the objects used as memory IDs are removed from the memory when the service goes out of scope.
-
-Users can provide their own custom `ChatMemoryProvider` for use in the AiService by implementing `Supplier`, such as:
+If for some reason different AI services need to have a different `ChatMemoryProvider` (i.e. not use the globally available bean), this is possible by configuring the `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation and implementing a custom provider.
+Here is a possible example:
[source,java]
----
-include::{examples-dir}/io/quarkiverse/langchain4j/samples/MySmallMemoryProvider.java[]
+include::{examples-dir}/io/quarkiverse/langchain4j/samples/CustomProvider.java[]
----
and configuring the AiService as so:
diff --git a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java
deleted file mode 100644
index 81c2759d0..000000000
--- a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java
+++ /dev/null
@@ -1,52 +0,0 @@
-package io.quarkiverse.langchain4j.openai.test;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.junit.jupiter.api.Assertions.fail;
-
-import jakarta.enterprise.inject.spi.DeploymentException;
-import jakarta.inject.Inject;
-import jakarta.inject.Singleton;
-
-import org.jboss.shrinkwrap.api.ShrinkWrap;
-import org.jboss.shrinkwrap.api.spec.JavaArchive;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
-
-import dev.langchain4j.agent.tool.Tool;
-import io.quarkiverse.langchain4j.RegisterAiService;
-import io.quarkus.test.QuarkusUnitTest;
-
-public class AiServiceWithToolsAndNoMemoryTest {
-
- @RegisterExtension
- static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
- .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses())
- .assertException(t -> {
- assertThat(t)
- .isInstanceOf(DeploymentException.class)
- .hasMessageContaining("ChatMemoryProvider");
- });
-
- @RegisterAiService(tools = CustomTool.class)
- interface Assistant {
-
- String chat(String input);
- }
-
- @Singleton
- static class CustomTool {
-
- @Tool
- void doSomething() {
-
- }
- }
-
- @Inject
- Assistant assistant;
-
- @Test
- void test() {
- fail("Should not be called");
- }
-}
diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java
index b65f791af..5b99cf5d7 100644
--- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java
+++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java
@@ -36,10 +36,7 @@
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
-import dev.langchain4j.memory.chat.ChatMemoryProvider;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.output.Response;
-import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
@@ -65,18 +62,6 @@ public class AuditingServiceTest {
static ObjectMapper mapper;
- public static class ChatMemoryProviderProducer {
-
- @Singleton
- ChatMemoryProvider chatMemory() {
- return memoryId -> MessageWindowChatMemory.builder()
- .id(memoryId)
- .maxMessages(10)
- .chatMemoryStore(new InMemoryChatMemoryStore())
- .build();
- }
- }
-
@BeforeAll
static void beforeAll() {
wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT));
diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/CustomChatMemoryProviderTest.java
similarity index 76%
rename from openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java
rename to openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/CustomChatMemoryProviderTest.java
index e8be0abd3..e0d8db70e 100644
--- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java
+++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/CustomChatMemoryProviderTest.java
@@ -1,29 +1,23 @@
package org.acme.examples.aiservices;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
-import static dev.langchain4j.data.message.ChatMessageDeserializer.messagesFromJson;
-import static dev.langchain4j.data.message.ChatMessageSerializer.messagesToJson;
import static dev.langchain4j.data.message.ChatMessageType.AI;
import static dev.langchain4j.data.message.ChatMessageType.USER;
import static org.acme.examples.aiservices.MessageAssertUtils.assertMultipleRequestMessage;
import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.tuple;
-import static org.assertj.core.api.InstanceOfAssertFactories.list;
-import static org.assertj.core.api.InstanceOfAssertFactories.map;
import java.io.IOException;
-import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.concurrent.CopyOnWriteArrayList;
+import jakarta.enterprise.inject.Produces;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
-import org.assertj.core.api.InstanceOfAssertFactory;
-import org.assertj.core.api.ListAssert;
-import org.assertj.core.api.MapAssert;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.AfterAll;
@@ -39,6 +33,7 @@
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.service.MemoryId;
@@ -49,8 +44,10 @@
import io.quarkus.arc.Arc;
import io.quarkus.test.QuarkusUnitTest;
-public class BeanDeclarativeAiServicesTest {
+public class CustomChatMemoryProviderTest {
+ public static final int FIRST_MEMORY_ID = 1;
+ public static final int SECOND_MEMORY_ID = 2;
private static final int WIREMOCK_PORT = 8089;
@RegisterExtension
@@ -61,18 +58,11 @@ public class BeanDeclarativeAiServicesTest {
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1");
private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() {
};
- private static final InstanceOfAssertFactory> MAP_STRING_STRING = map(String.class,
- String.class);
- private static final InstanceOfAssertFactory> LIST_MAP = list(Map.class);
static WireMockServer wireMockServer;
static ObjectMapper mapper;
- private static MessageWindowChatMemory createChatMemory() {
- return MessageWindowChatMemory.withMaxMessages(10);
- }
-
@BeforeAll
static void beforeAll() {
wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT));
@@ -92,47 +82,34 @@ void setup() {
wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub());
}
- public static class ChatMemoryProviderProducer {
+ @RegisterAiService
+ @Singleton
+ interface ChatWithSeparateMemoryForEachUser {
- @Singleton
- ChatMemoryProvider chatMemory(ChatMemoryStore store) {
- return memoryId -> MessageWindowChatMemory.builder()
- .id(memoryId)
- .maxMessages(10)
- .chatMemoryStore(store)
- .build();
- }
+ String chat(@MemoryId int memoryId, @UserMessage String userMessage);
}
- @Singleton
- public static class CustomChatMemoryStore implements ChatMemoryStore {
-
- // emulating persistent storage
- private final Map* memoryId */ Object, String> persistentStorage = new HashMap<>();
+ public static class CustomChatMemoryProviderProducer {
- @Override
- public List getMessages(Object memoryId) {
- return messagesFromJson(persistentStorage.get(memoryId));
- }
+ static final List MEMORY_IDS = new CopyOnWriteArrayList<>();
- @Override
- public void updateMessages(Object memoryId, List messages) {
- persistentStorage.put(memoryId, messagesToJson(messages));
- }
-
- @Override
- public void deleteMessages(Object memoryId) {
- persistentStorage.remove(memoryId);
+ @Singleton
+ @Produces
+ public ChatMemoryProvider chatMemoryProvider(ChatMemoryStore store) {
+ return new ChatMemoryProvider() {
+ @Override
+ public ChatMemory get(Object memoryId) {
+ MEMORY_IDS.add(memoryId);
+ return MessageWindowChatMemory.builder()
+ .maxMessages(10)
+ .id(memoryId)
+ .chatMemoryStore(store)
+ .build();
+ }
+ };
}
}
- @RegisterAiService
- @Singleton
- interface ChatWithSeparateMemoryForEachUser {
-
- String chat(@MemoryId int memoryId, @UserMessage String userMessage);
- }
-
@Inject
ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser;
@@ -141,14 +118,11 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
ChatMemoryStore store = Arc.container().instance(ChatMemoryStore.class).get();
- int firstMemoryId = 1;
- int secondMemoryId = 2;
-
/* **** First request for user 1 **** */
String firstMessageFromFirstUser = "Hello, my name is Klaus";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Nice to meet you Klaus"));
- String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, firstMessageFromFirstUser);
+ String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID, firstMessageFromFirstUser);
// assert response
assertThat(firstAiResponseToFirstUser).isEqualTo("Nice to meet you Klaus");
@@ -157,7 +131,7 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
assertSingleRequestMessage(getRequestAsMap(), firstMessageFromFirstUser);
// assert chat memory
- assertThat(store.getMessages(firstMemoryId)).hasSize(2)
+ assertThat(store.getMessages(FIRST_MEMORY_ID)).hasSize(2)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser));
@@ -167,7 +141,8 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
String firstMessageFromSecondUser = "Hello, my name is Francine";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Nice to meet you Francine"));
- String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId, firstMessageFromSecondUser);
+ String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID,
+ firstMessageFromSecondUser);
// assert response
assertThat(firstAiResponseToSecondUser).isEqualTo("Nice to meet you Francine");
@@ -176,7 +151,7 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
assertSingleRequestMessage(getRequestAsMap(), firstMessageFromSecondUser);
// assert chat memory
- assertThat(store.getMessages(secondMemoryId)).hasSize(2)
+ assertThat(store.getMessages(SECOND_MEMORY_ID)).hasSize(2)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser));
@@ -186,7 +161,8 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
String secondsMessageFromFirstUser = "What is my name?";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Your name is Klaus"));
- String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, secondsMessageFromFirstUser);
+ String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID,
+ secondsMessageFromFirstUser);
// assert response
assertThat(secondAiMessageToFirstUser).contains("Klaus");
@@ -199,7 +175,7 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
new MessageAssertUtils.MessageContent("user", secondsMessageFromFirstUser)));
// assert chat memory
- assertThat(store.getMessages(firstMemoryId)).hasSize(4)
+ assertThat(store.getMessages(FIRST_MEMORY_ID)).hasSize(4)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser),
tuple(USER, secondsMessageFromFirstUser), tuple(AI, secondAiMessageToFirstUser));
@@ -210,7 +186,7 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
String secondsMessageFromSecondUser = "What is my name?";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Your name is Francine"));
- String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId,
+ String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID,
secondsMessageFromSecondUser);
// assert response
@@ -224,10 +200,13 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
new MessageAssertUtils.MessageContent("user", secondsMessageFromSecondUser)));
// assert chat memory
- assertThat(store.getMessages(secondMemoryId)).hasSize(4)
+ assertThat(store.getMessages(SECOND_MEMORY_ID)).hasSize(4)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser),
tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser));
+
+ // assert that the custom ChatMemoryProvider was used
+ assertThat(CustomChatMemoryProviderProducer.MEMORY_IDS).containsExactly(FIRST_MEMORY_ID, SECOND_MEMORY_ID);
}
private Map getRequestAsMap() throws IOException {
diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/CustomChatMemoryStoreTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/CustomChatMemoryStoreTest.java
new file mode 100644
index 000000000..01e9b9181
--- /dev/null
+++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/CustomChatMemoryStoreTest.java
@@ -0,0 +1,262 @@
+package org.acme.examples.aiservices;
+
+import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options;
+import static dev.langchain4j.data.message.ChatMessageType.AI;
+import static dev.langchain4j.data.message.ChatMessageType.USER;
+import static org.acme.examples.aiservices.MessageAssertUtils.assertMultipleRequestMessage;
+import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.tuple;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import jakarta.enterprise.inject.Produces;
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.github.tomakehurst.wiremock.WireMockServer;
+import com.github.tomakehurst.wiremock.stubbing.ServeEvent;
+import com.github.tomakehurst.wiremock.verification.LoggedRequest;
+
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
+import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
+import io.quarkiverse.langchain4j.ChatMemoryRemover;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.openai.test.WiremockUtils;
+import io.quarkus.test.QuarkusUnitTest;
+
+public class CustomChatMemoryStoreTest {
+
+ public static final int FIRST_MEMORY_ID = 1;
+ public static final int SECOND_MEMORY_ID = 2;
+ private static final int WIREMOCK_PORT = 8089;
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(
+ () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class))
+ .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever")
+ .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1");
+ private static final TypeReference<Map<String, Object>> MAP_TYPE_REF = new TypeReference<>() {
+ };
+
+ static WireMockServer wireMockServer;
+
+ static ObjectMapper mapper;
+
+ @BeforeAll
+ static void beforeAll() {
+ wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT));
+ wireMockServer.start();
+
+ mapper = new ObjectMapper();
+ }
+
+ @AfterAll
+ static void afterAll() {
+ wireMockServer.stop();
+ }
+
+ @BeforeEach
+ void setup() {
+ wireMockServer.resetAll();
+ wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub());
+ }
+
+ @RegisterAiService
+ @Singleton
+ interface ChatWithSeparateMemoryForEachUser {
+
+ String chat(@MemoryId int memoryId, @UserMessage String userMessage);
+ }
+
+ public static class CustomChatMemoryStore extends InMemoryChatMemoryStore {
+
+ static AtomicInteger GET_MESSAGES_COUNT = new AtomicInteger();
+ static AtomicInteger UPDATE_MESSAGES_COUNT = new AtomicInteger();
+ static AtomicInteger DELETE_MESSAGES_COUNT = new AtomicInteger();
+
+ @Override
+ public List getMessages(Object memoryId) {
+ GET_MESSAGES_COUNT.incrementAndGet();
+ return super.getMessages(memoryId);
+ }
+
+ @Override
+ public void updateMessages(Object memoryId, List messages) {
+ UPDATE_MESSAGES_COUNT.incrementAndGet();
+ super.updateMessages(memoryId, messages);
+ }
+
+ @Override
+ public void deleteMessages(Object memoryId) {
+ DELETE_MESSAGES_COUNT.incrementAndGet();
+ super.deleteMessages(memoryId);
+ }
+ }
+
+ public static class CustomChatMemoryStoreProducer {
+
+ @Singleton
+ @Produces
+ public ChatMemoryStore customStore() {
+ return new CustomChatMemoryStore();
+ }
+ }
+
+ @Inject
+ ChatMemoryStore chatMemoryStore;
+
+ @Inject
+ ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser;
+
+ @Test
+ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException {
+ // assert the bean type is correct
+ assertThat(chatMemoryStore).isInstanceOf(CustomChatMemoryStore.class);
+
+ /* **** First request for user 1 **** */
+ String firstMessageFromFirstUser = "Hello, my name is Klaus";
+ wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
+ "Nice to meet you Klaus"));
+ String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID, firstMessageFromFirstUser);
+
+ // assert response
+ assertThat(firstAiResponseToFirstUser).isEqualTo("Nice to meet you Klaus");
+
+ // assert request
+ assertSingleRequestMessage(getRequestAsMap(), firstMessageFromFirstUser);
+
+ // assert chat memory
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).hasSize(2)
+ .extracting(ChatMessage::type, ChatMessage::text)
+ .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser));
+
+ /* **** First request for user 2 **** */
+ wireMockServer.resetRequests();
+
+ String firstMessageFromSecondUser = "Hello, my name is Francine";
+ wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
+ "Nice to meet you Francine"));
+ String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID,
+ firstMessageFromSecondUser);
+
+ // assert response
+ assertThat(firstAiResponseToSecondUser).isEqualTo("Nice to meet you Francine");
+
+ // assert request
+ assertSingleRequestMessage(getRequestAsMap(), firstMessageFromSecondUser);
+
+ // assert chat memory
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).hasSize(2)
+ .extracting(ChatMessage::type, ChatMessage::text)
+ .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser));
+
+ /* **** Second request for user 1 **** */
+ wireMockServer.resetRequests();
+
+ String secondsMessageFromFirstUser = "What is my name?";
+ wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
+ "Your name is Klaus"));
+ String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID,
+ secondsMessageFromFirstUser);
+
+ // assert response
+ assertThat(secondAiMessageToFirstUser).contains("Klaus");
+
+ // assert request
+ assertMultipleRequestMessage(getRequestAsMap(),
+ List.of(
+ new MessageAssertUtils.MessageContent("user", firstMessageFromFirstUser),
+ new MessageAssertUtils.MessageContent("assistant", firstAiResponseToFirstUser),
+ new MessageAssertUtils.MessageContent("user", secondsMessageFromFirstUser)));
+
+ // assert chat memory
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).hasSize(4)
+ .extracting(ChatMessage::type, ChatMessage::text)
+ .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser),
+ tuple(USER, secondsMessageFromFirstUser), tuple(AI, secondAiMessageToFirstUser));
+
+ /* **** Second request for user 2 **** */
+ wireMockServer.resetRequests();
+
+ String secondsMessageFromSecondUser = "What is my name?";
+ wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
+ "Your name is Francine"));
+ String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID,
+ secondsMessageFromSecondUser);
+
+ // assert response
+ assertThat(secondAiMessageToSecondUser).contains("Francine");
+
+ // assert request
+ assertMultipleRequestMessage(getRequestAsMap(),
+ List.of(
+ new MessageAssertUtils.MessageContent("user", firstMessageFromSecondUser),
+ new MessageAssertUtils.MessageContent("assistant", firstAiResponseToSecondUser),
+ new MessageAssertUtils.MessageContent("user", secondsMessageFromSecondUser)));
+
+ // assert chat memory
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).hasSize(4)
+ .extracting(ChatMessage::type, ChatMessage::text)
+ .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser),
+ tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser));
+
+ // assert our custom chat memory store is used
+ assertThat(CustomChatMemoryStore.GET_MESSAGES_COUNT).hasPositiveValue();
+ assertThat(CustomChatMemoryStore.UPDATE_MESSAGES_COUNT).hasPositiveValue();
+
+ // assert delete has not been called because the AI service is a singleton
+ assertThat(CustomChatMemoryStore.DELETE_MESSAGES_COUNT).hasValue(0);
+
+ // remove the first entry
+ ChatMemoryRemover.remove(chatWithSeparateMemoryForEachUser, FIRST_MEMORY_ID);
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).isEmpty();
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).isNotEmpty();
+
+ // remove the second entry
+ ChatMemoryRemover.remove(chatWithSeparateMemoryForEachUser, SECOND_MEMORY_ID);
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).isEmpty();
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).isEmpty();
+
+ // now assert that our store was used for delete
+ assertThat(CustomChatMemoryStore.DELETE_MESSAGES_COUNT).hasValue(2);
+ }
+
+ private Map getRequestAsMap() throws IOException {
+ return getRequestAsMap(getRequestBody());
+ }
+
+ private Map getRequestAsMap(byte[] body) throws IOException {
+ return mapper.readValue(body, MAP_TYPE_REF);
+ }
+
+ private byte[] getRequestBody() {
+ assertThat(wireMockServer.getAllServeEvents()).hasSize(1);
+ ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this works because we reset requests for Wiremock before each test
+ return getRequestBody(serveEvent);
+ }
+
+ private byte[] getRequestBody(ServeEvent serveEvent) {
+ LoggedRequest request = serveEvent.getRequest();
+ assertThat(request.getBody()).isNotEmpty();
+ return request.getBody();
+ }
+}
diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java
index 8e4a218d1..8a25b94c7 100644
--- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java
+++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java
@@ -10,8 +10,6 @@
import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.tuple;
-import static org.assertj.core.api.InstanceOfAssertFactories.list;
-import static org.assertj.core.api.InstanceOfAssertFactories.map;
import java.io.IOException;
import java.util.HashMap;
@@ -23,9 +21,6 @@
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
-import org.assertj.core.api.InstanceOfAssertFactory;
-import org.assertj.core.api.ListAssert;
-import org.assertj.core.api.MapAssert;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.AfterAll;
@@ -65,18 +60,11 @@ public class DeclarativeAiServicesTest {
.overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1");
private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() {
};
- private static final InstanceOfAssertFactory> MAP_STRING_STRING = map(String.class,
- String.class);
- private static final InstanceOfAssertFactory> LIST_MAP = list(Map.class);
static WireMockServer wireMockServer;
static ObjectMapper mapper;
- private static MessageWindowChatMemory createChatMemory() {
- return MessageWindowChatMemory.withMaxMessages(10);
- }
-
@BeforeAll
static void beforeAll() {
wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT));
@@ -96,7 +84,7 @@ void setup() {
wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub());
}
- @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
+ @RegisterAiService
interface Assistant {
String chat(String message);
@@ -120,7 +108,7 @@ enum Sentiment {
NEGATIVE
}
- @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
+ @RegisterAiService
interface SentimentAnalyzer {
@UserMessage("Analyze sentiment of {it}")
diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java
index f586c8401..ec720e27c 100644
--- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java
+++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/RemovableChatMemoryTest.java
@@ -12,9 +12,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.concurrent.ConcurrentHashMap;
-import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import org.jboss.shrinkwrap.api.ShrinkWrap;
@@ -32,12 +30,11 @@
import com.github.tomakehurst.wiremock.verification.LoggedRequest;
import dev.langchain4j.data.message.ChatMessage;
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
+import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkiverse.langchain4j.RegisterAiService;
-import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
import io.quarkiverse.langchain4j.openai.test.WiremockUtils;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
@@ -45,6 +42,8 @@
public class RemovableChatMemoryTest {
+ public static final int FIRST_MEMORY_ID = 1; // memory id passed to chat(...) for the first user's conversation
+ public static final int SECOND_MEMORY_ID = 2; // memory id passed to chat(...) for the second user's conversation
private static final int WIREMOCK_PORT = 8089;
@RegisterExtension
@@ -79,31 +78,15 @@ void setup() {
wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub());
}
- @ApplicationScoped
- public static class ChatMemoryBean implements RemovableChatMemoryProvider {
-
- static final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(20)
- .id(memoryId)
- .build());
- }
-
- @Override
- public void remove(Object id) {
- memories.remove(id);
- }
- }
-
@RegisterAiService
interface ChatWithSeparateMemoryForEachUser {
String chat(@MemoryId int memoryId, @UserMessage String userMessage);
}
+ @Inject
+ ChatMemoryStore chatMemoryStore;
+
@Inject
ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser;
@@ -113,27 +96,7 @@ void should_keep_separate_chat_memory_for_each_user_in_store() throws IOExceptio
ManagedContext requestContext = Arc.container().requestContext();
// add a dummy entry that should not affect the chat in any way
- ChatMemoryBean.memories.put("DUMMY", new ChatMemory() {
- @Override
- public Object id() {
- return null;
- }
-
- @Override
- public void add(ChatMessage message) {
-
- }
-
- @Override
- public List messages() {
- return null;
- }
-
- @Override
- public void clear() {
-
- }
- });
+ chatMemoryStore.updateMessages("DUMMY", List.of(new SystemMessage("dummy")));
try {
requestContext.activate();
@@ -143,18 +106,18 @@ public void clear() {
}
// since the request context was closed, we should now only have the initial dummy entry
- assertThat(ChatMemoryBean.memories).containsOnlyKeys("DUMMY");
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).isEmpty();
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).isEmpty();
+ assertThat(chatMemoryStore.getMessages("DUMMY")).hasSize(1);
}
private void testInRequestContext() throws IOException {
- int firstMemoryId = 1;
- int secondMemoryId = 2;
/* **** First request for user 1 **** */
String firstMessageFromFirstUser = "Hello, my name is Klaus";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Nice to meet you Klaus"));
- String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, firstMessageFromFirstUser);
+ String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID, firstMessageFromFirstUser);
// assert response
assertThat(firstAiResponseToFirstUser).isEqualTo("Nice to meet you Klaus");
@@ -163,7 +126,7 @@ private void testInRequestContext() throws IOException {
assertSingleRequestMessage(getRequestAsMap(), firstMessageFromFirstUser);
// assert chat memory
- assertThat(ChatMemoryBean.memories.get(firstMemoryId).messages()).hasSize(2)
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).hasSize(2)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser));
@@ -173,7 +136,8 @@ private void testInRequestContext() throws IOException {
String firstMessageFromSecondUser = "Hello, my name is Francine";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Nice to meet you Francine"));
- String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId, firstMessageFromSecondUser);
+ String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID,
+ firstMessageFromSecondUser);
// assert response
assertThat(firstAiResponseToSecondUser).isEqualTo("Nice to meet you Francine");
@@ -182,7 +146,7 @@ private void testInRequestContext() throws IOException {
assertSingleRequestMessage(getRequestAsMap(), firstMessageFromSecondUser);
// assert chat memory
- assertThat(ChatMemoryBean.memories.get(secondMemoryId).messages()).hasSize(2)
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).hasSize(2)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser));
@@ -192,7 +156,8 @@ private void testInRequestContext() throws IOException {
String secondsMessageFromFirstUser = "What is my name?";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Your name is Klaus"));
- String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, secondsMessageFromFirstUser);
+ String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(FIRST_MEMORY_ID,
+ secondsMessageFromFirstUser);
// assert response
assertThat(secondAiMessageToFirstUser).contains("Klaus");
@@ -205,7 +170,7 @@ private void testInRequestContext() throws IOException {
new MessageAssertUtils.MessageContent("user", secondsMessageFromFirstUser)));
// assert chat memory
- assertThat(ChatMemoryBean.memories.get(firstMemoryId).messages()).hasSize(4)
+ assertThat(chatMemoryStore.getMessages(FIRST_MEMORY_ID)).hasSize(4)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser),
tuple(USER, secondsMessageFromFirstUser), tuple(AI, secondAiMessageToFirstUser));
@@ -216,7 +181,7 @@ private void testInRequestContext() throws IOException {
String secondsMessageFromSecondUser = "What is my name?";
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(),
"Your name is Francine"));
- String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId,
+ String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(SECOND_MEMORY_ID,
secondsMessageFromSecondUser);
// assert response
@@ -230,7 +195,7 @@ private void testInRequestContext() throws IOException {
new MessageAssertUtils.MessageContent("user", secondsMessageFromSecondUser)));
// assert chat memory
- assertThat(ChatMemoryBean.memories.get(secondMemoryId).messages()).hasSize(4)
+ assertThat(chatMemoryStore.getMessages(SECOND_MEMORY_ID)).hasSize(4)
.extracting(ChatMessage::type, ChatMessage::text)
.containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser),
tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser));
diff --git a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java
index d692238e5..8cddc0a72 100644
--- a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java
+++ b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java
@@ -8,6 +8,8 @@
import org.eclipse.microprofile.context.ManagedExecutor;
+import io.quarkiverse.langchain4j.ChatMemoryRemover;
+
@ServerEndpoint("/chatbot")
public class ChatBotWebSocket {
@@ -17,9 +19,6 @@ public class ChatBotWebSocket {
@Inject
ManagedExecutor managedExecutor;
- @Inject
- ChatMemoryBean chatMemoryBean;
-
@OnOpen
public void onOpen(Session session) {
managedExecutor.execute(() -> {
@@ -34,7 +33,7 @@ public void onOpen(Session session) {
@OnClose
void onClose(Session session) {
- chatMemoryBean.clear(session);
+ ChatMemoryRemover.remove(bot, session);
}
@OnMessage
diff --git a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatMemoryBean.java b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatMemoryBean.java
deleted file mode 100644
index 1f353a25c..000000000
--- a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatMemoryBean.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package io.quarkiverse.langchain4j.sample.chatbot;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-import jakarta.enterprise.context.ApplicationScoped;
-
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.ChatMemoryProvider;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
-
-@ApplicationScoped
-public class ChatMemoryBean implements ChatMemoryProvider {
-
- private final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(20)
- .id(memoryId)
- .build());
- }
-
- public void clear(Object session) {
- memories.remove(session);
- }
-}
diff --git a/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java b/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java
index 0f5d6268a..08e7c3c4b 100644
--- a/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java
+++ b/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatBotWebSocket.java
@@ -11,6 +11,8 @@
import org.eclipse.microprofile.context.ManagedExecutor;
+import io.quarkiverse.langchain4j.ChatMemoryRemover;
+
@ServerEndpoint("/chatbot")
public class ChatBotWebSocket {
@@ -20,9 +22,6 @@ public class ChatBotWebSocket {
@Inject
ManagedExecutor managedExecutor;
- @Inject
- ChatMemoryBean chatMemoryBean;
-
@OnOpen
public void onOpen(Session session) {
managedExecutor.execute(() -> {
@@ -37,7 +36,7 @@ public void onOpen(Session session) {
@OnClose
void onClose(Session session) {
- chatMemoryBean.clear(session);
+ ChatMemoryRemover.remove(bot, session);
}
@OnMessage
diff --git a/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatMemoryBean.java b/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatMemoryBean.java
deleted file mode 100644
index 079dc84ce..000000000
--- a/samples/csv-chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/ChatMemoryBean.java
+++ /dev/null
@@ -1,28 +0,0 @@
-package io.quarkiverse.langchain4j.sample.chatbot;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-import jakarta.enterprise.context.ApplicationScoped;
-
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.ChatMemoryProvider;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
-
-@ApplicationScoped
-public class ChatMemoryBean implements ChatMemoryProvider {
-
- private final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(10)
- .id(memoryId)
- .build());
- }
-
- public void clear(Object session) {
- memories.remove(session);
- }
-}
diff --git a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java
deleted file mode 100644
index 4ab1056c0..000000000
--- a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/ChatMemoryBean.java
+++ /dev/null
@@ -1,29 +0,0 @@
-package io.quarkiverse.langchain4j.sample;
-
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-import jakarta.inject.Singleton;
-
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.MessageWindowChatMemory;
-import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
-
-@Singleton
-public class ChatMemoryBean implements RemovableChatMemoryProvider {
-
- private final Map memories = new ConcurrentHashMap<>();
-
- @Override
- public ChatMemory get(Object memoryId) {
- return memories.computeIfAbsent(memoryId, id -> MessageWindowChatMemory.builder()
- .maxMessages(20)
- .id(memoryId)
- .build());
- }
-
- @Override
- public void remove(Object id) {
- memories.remove(id);
- }
-}