Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul how chat memory is configured #112

Merged
merged 1 commit into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
Expand Down Expand Up @@ -106,6 +107,9 @@ public class AiServicesProcessor {

private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_CLOSE = MethodDescriptor.ofMethod(
QuarkusAiServiceContext.class, "close", void.class);

private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_REMOVE_CHAT_MEMORY_IDS = MethodDescriptor.ofMethod(
QuarkusAiServiceContext.class, "removeChatMemoryIds", void.class, Object[].class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

@BuildStep
Expand Down Expand Up @@ -184,16 +188,11 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}

// by default a ChatMemoryProvider bean is required, regardless of whether tools are configured
DotName chatMemoryProviderSupplierClassDotName = toolDotNames.isEmpty()
? Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER
: Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
if (chatMemoryProviderSupplierValue != null) {
chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
if (chatMemoryProviderSupplierClassDotName.equals(
Langchain4jDotNames.NO_CHAT_MEMORY_PROVIDER_SUPPLIER)) {
chatMemoryProviderSupplierClassDotName = null;
} else if (!chatMemoryProviderSupplierClassDotName
if (!chatMemoryProviderSupplierClassDotName
.equals(Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER)) {
validateSupplierAndRegisterForReflection(chatMemoryProviderSupplierClassDotName, index,
reflectiveClassProducer);
Expand Down Expand Up @@ -313,11 +312,6 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER));
needsChatMemoryProviderBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
.equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
needsChatMemoryProviderBean = true;
}

if (Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER.toString().equals(retrieverSupplierClassName)) {
Expand Down Expand Up @@ -472,7 +466,7 @@ public void handleAiServices(AiServicesRecorder recorder,
ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(classOutput)
.className(implClassName)
.interfaces(ifaceName);
.interfaces(ifaceName, ChatMemoryRemovable.class.getName());
if (isRegisteredService) {
classCreatorBuilder.interfaces(AutoCloseable.class);
}
Expand Down Expand Up @@ -524,6 +518,16 @@ public void handleAiServices(AiServicesRecorder recorder,
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
mc.returnVoid();
}

{
MethodCreator mc = classCreator.getMethodCreator(
MethodDescriptor.ofMethod(implClassName, "remove", void.class, Object[].class));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_REMOVE_CHAT_MEMORY_IDS, contextHandle,
mc.getMethodParam(0));
mc.returnVoid();
}

}
perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
// make the constructor accessible reflectively since that is how we create the instance
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.quarkiverse.langchain4j.deployment;

import static io.quarkus.runtime.annotations.ConfigPhase.BUILD_TIME;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;
import io.smallrye.config.WithDefault;

/**
* Build-time configuration mapped from the {@code quarkus.langchain4j.chat-memory} prefix.
*/
@ConfigRoot(phase = BUILD_TIME)
@ConfigMapping(prefix = "quarkus.langchain4j.chat-memory")
public interface ChatMemoryBuildConfig {

/**
* Configure the type of {@link ChatMemory} that will be used by default by the default {@link ChatMemoryProvider} bean.
* <p>
* The extension provides a default bean that configures {@link ChatMemoryProvider} for use with AI services
* registered with {@link RegisterAiService}. This bean uses the {@code quarkus.langchain4j.chat-memory}
* configuration to set things up while also depending on the presence of a bean of type {@link ChatMemoryStore} (for which
* the extension also provides a default in the form of {@link InMemoryChatMemoryStore}).
* <p>
* If {@code token-window} is used, then the application must also provide a bean of type {@link Tokenizer}.
* <p>
* Users can choose to provide their own {@link ChatMemoryStore} bean or even their own {@link ChatMemoryProvider} bean
* if full control over the details is needed.
*/
@WithDefault("MESSAGE_WINDOW")
Type type();

/**
* The kinds of {@link ChatMemory} that the default {@link ChatMemoryProvider} bean can be configured to use.
*/
enum Type {
// the default value of the 'type' property (see @WithDefault above)
MESSAGE_WINDOW,
// selecting this value requires the application to provide a Tokenizer bean
TOKEN_WINDOW
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.quarkiverse.langchain4j.deployment;

import java.util.function.Function;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.ClassType;

import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkiverse.langchain4j.runtime.ChatMemoryRecorder;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryConfig;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;

/**
 * Build-time processor that registers the default {@link ChatMemoryProvider} synthetic bean.
 */
public class ChatMemoryProcessor {

    /**
     * Produces an {@link ApplicationScoped}, runtime-initialized default bean of type {@link ChatMemoryProvider}
     * that has a {@link ChatMemoryStore} injection point and is created by a function obtained from the recorder.
     *
     * @param buildConfig build-time configuration used to select the kind of chat memory
     * @param runtimeConfig runtime configuration handed to the recorder
     * @param recorder recorder supplying the function that creates the bean
     * @param syntheticBeanProducer producer receiving the configured synthetic bean
     * @throws IllegalStateException if the configured type is neither {@code MESSAGE_WINDOW} nor {@code TOKEN_WINDOW}
     */
    @BuildStep
    @Record(ExecutionTime.RUNTIME_INIT)
    void setupBeans(ChatMemoryBuildConfig buildConfig, ChatMemoryConfig runtimeConfig,
            ChatMemoryRecorder recorder,
            BuildProducer<SyntheticBeanBuildItem> syntheticBeanProducer) {

        SyntheticBeanBuildItem.ExtendedBeanConfigurator beanConfigurator = SyntheticBeanBuildItem
                .configure(ChatMemoryProvider.class)
                .setRuntimeInit()
                .addInjectionPoint(ClassType.create(ChatMemoryStore.class))
                .scope(ApplicationScoped.class)
                .defaultBean();

        Function<SyntheticCreationalContext<ChatMemoryProvider>, ChatMemoryProvider> creator = selectCreator(
                buildConfig.type(), beanConfigurator, runtimeConfig, recorder);

        beanConfigurator.createWith(creator);
        syntheticBeanProducer.produce(beanConfigurator.done());
    }

    /**
     * Picks the recorder function matching the configured memory type; the token-window variant additionally
     * registers a {@link Tokenizer} injection point on the bean.
     */
    private static Function<SyntheticCreationalContext<ChatMemoryProvider>, ChatMemoryProvider> selectCreator(
            ChatMemoryBuildConfig.Type memoryType,
            SyntheticBeanBuildItem.ExtendedBeanConfigurator beanConfigurator,
            ChatMemoryConfig runtimeConfig,
            ChatMemoryRecorder recorder) {

        if (memoryType == ChatMemoryBuildConfig.Type.MESSAGE_WINDOW) {
            return recorder.messageWindow(runtimeConfig);
        }
        if (memoryType == ChatMemoryBuildConfig.Type.TOKEN_WINDOW) {
            beanConfigurator.addInjectionPoint(ClassType.create(Tokenizer.class));
            return recorder.tokenWindow(runtimeConfig);
        }
        throw new IllegalStateException(
                "Invalid configuration '" + memoryType + "' used in 'quarkus.langchain4j.chat-memory.type'");
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ public class Langchain4jDotNames {

static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatMemoryProviderSupplier.class);
static final DotName BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class);
static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.NoChatMemoryProviderSupplier.class);

static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.quarkiverse.langchain4j;

import java.util.List;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;

/**
* Allows the application to manually control when a {@link ChatMemory} should be removed from the underlying
* {@link ChatMemoryStore}.
*/
public final class ChatMemoryRemover {

private static final Object[] EMPTY_OBJECT_ARRAY = new Object[0];

private ChatMemoryRemover() {
}

/**
* @param aiService The bean that implements the AI Service annotated with {@link RegisterAiService}
* @param memoryId The object used as memory IDs for which the corresponding {@link ChatMemory} should be removed
*/
public static void remove(Object aiService, Object memoryId) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like using two objects here as users could easily swap them, but I don't have a better way of handling it...

if (aiService instanceof ChatMemoryRemovable r) {
r.remove(memoryId);
}
}

/**
* @param aiService The bean that implements the AI Service annotated with {@link RegisterAiService}
* @param memoryIds The objects used as memory IDs for which the corresponding {@link ChatMemory} should be removed
*/
public static void remove(Object aiService, List<Object> memoryIds) {
if (aiService instanceof ChatMemoryRemovable r) {
r.remove(memoryIds.toArray(EMPTY_OBJECT_ARRAY));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import io.quarkiverse.langchain4j.audit.AuditService;

/**
Expand Down Expand Up @@ -44,27 +47,29 @@

/**
* Tool classes to use. All tools are expected to be CDI beans.
* <p>
* NOTE: when this is used, either a {@link ChatMemoryProvider} bean must be present in the application, or a custom
* {@link Supplier<ChatMemoryProvider>} must be set.
*/
Class<?>[] tools() default {};

/**
* Configures the way to obtain the {@link ChatMemoryProvider} to use.
* By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link ChatMemoryProvider} instance is needed, a custom implementation of
* {@link Supplier<ChatMemoryProvider>} needs to be provided.
* Configures the way to obtain the {@link ChatMemoryProvider}.
* <p>
* By default, Quarkus configures a {@link ChatMemoryProvider} bean that uses an {@link InMemoryChatMemoryStore} bean
* as the backing store. The default type for the actual {@link ChatMemory} is {@link MessageWindowChatMemory}
* and it is configured with the value of the {@code quarkus.langchain4j.chat-memory.memory-window.max-messages}
* configuration property (which defaults to 10) as a way of limiting the number of messages in each chat.
* <p>
* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then
* set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier}
* If the application provides its own {@link ChatMemoryProvider} bean, that takes precedence over what Quarkus provides as
* the default.
* <p>
* If the application provides an implementation of {@link ChatMemoryStore}, then that is used instead of the default
* {@link InMemoryChatMemoryStore}.
* <p>
* In the most advanced case, an arbitrary {@link ChatMemoryProvider} can be used by having a custom
* {@code Supplier<ChatMemoryProvider>} configured in this property.
* {@link Supplier<ChatMemoryProvider>} needs to be provided.
* <p>
* NOTE: when {@link tools} is set, the default is changed to {@link BeanChatMemoryProviderSupplier} which means that a
* bean a {@link ChatMemoryProvider} bean must be present. The alternative in this case is to set a custom
* {@link Supplier<ChatMemoryProvider>}.
*/
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class;
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanChatMemoryProviderSupplier.class;

/**
* Configures the way to obtain the {@link Retriever} to use (when using RAG).
Expand Down Expand Up @@ -98,8 +103,11 @@ public ChatLanguageModel get() {
}

/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. If the bean does
* not exist, Quarkus will fail at build time.
* Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that is configured as a CDI bean.
* By default, Quarkus configures a {@link ChatMemoryProvider} by using an {@link InMemoryChatMemoryStore}
* as the backing store while using {@link MessageWindowChatMemory} with the value of
* configuration property {@code quarkus.langchain4j.chat-memory.memory-window.max-messages} (which defaults to 10)
* as a way of limiting the number of messages in each chat.
*/
final class BeanChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

Expand All @@ -109,29 +117,6 @@ public ChatMemoryProvider get() {
}
}

/**
* Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that the user has configured as a CDI bean.
* If no such bean exists, then no memory will be used.
*/
final class BeanIfExistsChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

@Override
public ChatMemoryProvider get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no chat memory should be used
*/
final class NoChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

@Override
public ChatMemoryProvider get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean
*/
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,6 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
.equals(info.getChatMemoryProviderSupplierClassName())) {
quarkusAiServices.chatMemoryProvider(creationalContext.getInjectedReference(
ChatMemoryProvider.class));
} else if (RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class.getName()
.equals(info.getChatMemoryProviderSupplierClassName())) {
Instance<ChatMemoryProvider> instance = creationalContext
.getInjectedReference(CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.chatMemoryProvider(instance.get());
}
} else {
Supplier<? extends ChatMemoryProvider> supplier = (Supplier<? extends ChatMemoryProvider>) Thread
.currentThread().getContextClassLoader()
Expand Down
Loading