Skip to content

Commit

Permalink
Make @RegisterAiService beans request scoped by default
Browse files Browse the repository at this point in the history
This is done because otherwise the chat memory
does not get cleared properly.

Furthermore, add a way to remove memory entries when the service goes out of scope

Fixes: #95
  • Loading branch information
geoand committed Dec 6, 2023
1 parent 574e0e1 commit a736eed
Show file tree
Hide file tree
Showing 15 changed files with 409 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;

import org.jboss.jandex.AnnotationInstance;
Expand All @@ -50,6 +49,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
Expand All @@ -60,6 +60,8 @@
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand Down Expand Up @@ -101,6 +103,9 @@ public class AiServicesProcessor {
private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(
AiServiceMethodImplementationSupport.class,
"implement", Object.class, AiServiceMethodImplementationSupport.Input.class);

private static final MethodDescriptor QUARKUS_AI_SERVICES_CONTEXT_CLOSE = MethodDescriptor.ofMethod(
QuarkusAiServiceContext.class, "close", void.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

@BuildStep
Expand Down Expand Up @@ -211,14 +216,18 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer);
}

BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo);
ScopeInfo cdiScope = declaredScope != null ? declaredScope.getInfo() : BuiltinScope.REQUEST.getInfo();

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName,
auditServiceClassSupplierName));
auditServiceClassSupplierName,
cdiScope));
}

if (needChatModelBean) {
Expand Down Expand Up @@ -285,8 +294,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName,
auditServiceClassSupplierName)))
.destroyer(DeclarativeAiServiceBeanDestroyer.class)
.setRuntimeInit()
.scope(ApplicationScoped.class);
.scope(bi.getCdiScope());
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL));
needsChatModelBean = true;
Expand Down Expand Up @@ -403,8 +413,10 @@ public void handleAiServices(AiServicesRecorder recorder,
Set<String> detectedForCreate = new HashSet<>(nameToUsed.keySet());
addCreatedAware(index, detectedForCreate);
addIfacesWithMessageAnns(index, detectedForCreate);
detectedForCreate.addAll(declarativeAiServiceItems.stream().map(bi -> bi.getServiceClassInfo().name().toString())
.collect(Collectors.toList()));
Set<String> registeredAiServiceClassNames = declarativeAiServiceItems.stream()
.map(bi -> bi.getServiceClassInfo().name().toString()).collect(
Collectors.toUnmodifiableSet());
detectedForCreate.addAll(registeredAiServiceClassNames);

Set<ClassInfo> ifacesForCreate = new HashSet<>();
for (String className : detectedForCreate) {
Expand Down Expand Up @@ -453,12 +465,18 @@ public void handleAiServices(AiServicesRecorder recorder,
methodsToImplement.add(method);
}

String implClassName = iface.name().toString() + "$$QuarkusImpl";
try (ClassCreator classCreator = ClassCreator.builder()
String ifaceName = iface.name().toString();
String implClassName = ifaceName + "$$QuarkusImpl";
boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName);

ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(classOutput)
.className(implClassName)
.interfaces(iface.name().toString())
.build()) {
.interfaces(ifaceName);
if (isRegisteredService) {
classCreatorBuilder.interfaces(AutoCloseable.class);
}
try (ClassCreator classCreator = classCreatorBuilder.build()) {

FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
Expand All @@ -480,7 +498,7 @@ public void handleAiServices(AiServicesRecorder recorder,
MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
mc.load(iface.name().toString()),
mc.load(ifaceName),
mc.load(methodId));
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
Expand All @@ -498,8 +516,16 @@ public void handleAiServices(AiServicesRecorder recorder,

aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
}

if (isRegisteredService) {
MethodCreator mc = classCreator.getMethodCreator(
MethodDescriptor.ofMethod(implClassName, "close", void.class));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
mc.returnVoid();
}
}
perClassMetadata.put(iface.name().toString(), new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
perClassMetadata.put(ifaceName, new AiServiceClassCreateInfo(perMethodMetadata, implClassName));
// make the constructor accessible reflectively since that is how we create the instance
reflectiveClassProducer.produce(ReflectiveClassBuildItem.builder(implClassName).build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;

import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
Expand All @@ -19,18 +20,21 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName auditServiceClassSupplierDotName;
private final ScopeInfo cdiScope;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName,
DotName auditServiceClassSupplierDotName) {
DotName auditServiceClassSupplierDotName,
ScopeInfo cdiScope) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
this.cdiScope = cdiScope;
}

public ClassInfo getServiceClassInfo() {
Expand All @@ -56,4 +60,8 @@ public DotName getRetrieverSupplierClassDotName() {
public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}

public ScopeInfo getCdiScope() {
return cdiScope;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import java.lang.annotation.Target;
import java.util.function.Supplier;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -24,7 +22,9 @@
* while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional),
* {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist).
* <p>
* NOTE: The resulting CDI bean is {@link ApplicationScoped}.
* NOTE: The resulting CDI bean is {@link jakarta.enterprise.context.RequestScoped} be default. If you need to change the scope,
* simply annotate the class with a CDI scope.
* CAUTION: When using anything other than the request scope, you need to be very careful with the chat memory implementation.
* <p>
* NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated
* for the method invocations.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package io.quarkiverse.langchain4j;

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;

/**
* Extends {@link ChatMemoryProvider} to allow for removing {@link ChatMemory}
* when it is no longer needed.
*/
public interface RemovableChatMemoryProvider extends ChatMemoryProvider {

void remove(Object id);
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[]
}

Object memoryId = memoryId(createInfo, methodArgs).orElse("default");
context.usedMemoryIds.add(memoryId);

if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.Map;

import jakarta.enterprise.context.spi.CreationalContext;

import org.jboss.logging.Logger;

import io.quarkus.arc.BeanDestroyer;

public class DeclarativeAiServiceBeanDestroyer implements BeanDestroyer<AutoCloseable> {

private static final Logger log = Logger.getLogger(DeclarativeAiServiceBeanDestroyer.class);

@Override
public void destroy(AutoCloseable instance, CreationalContext<AutoCloseable> creationalContext,
Map<String, Object> params) {
try {
instance.close();
} catch (Exception e) {
log.error("Unable to close " + instance);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,44 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import dev.langchain4j.service.AiServiceContext;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.RemovableChatMemoryProvider;
import io.quarkiverse.langchain4j.audit.AuditService;

public class QuarkusAiServiceContext extends AiServiceContext {

public AuditService auditService;

public Set<Object> usedMemoryIds = ConcurrentHashMap.newKeySet();

public QuarkusAiServiceContext(Class<?> aiServiceClass) {
super(aiServiceClass);
}

/**
* This is called by the {@code close} method of AiServices registered with {@link RegisterAiService}
* when the bean's scope is closed
*/
public void close() {
removeChatMemories();
}

private void removeChatMemories() {
if (usedMemoryIds.isEmpty()) {
return;
}
RemovableChatMemoryProvider removableChatMemoryProvider = null;
if (chatMemoryProvider instanceof RemovableChatMemoryProvider) {
removableChatMemoryProvider = (RemovableChatMemoryProvider) chatMemoryProvider;
}
for (Object memoryId : usedMemoryIds) {
if (removableChatMemoryProvider != null) {
removableChatMemoryProvider.remove(memoryId);
}
chatMemories.remove(memoryId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ static class Calculator {
}

@RegisterAiService(tools = Calculator.class)
@Singleton
interface Assistant {

String chat(String message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public void deleteMessages(Object memoryId) {
}

@RegisterAiService
@Singleton
interface ChatWithSeparateMemoryForEachUser {

String chat(@MemoryId int memoryId, @UserMessage String userMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Map;
import java.util.Optional;

import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

Expand Down Expand Up @@ -105,6 +106,7 @@ interface Assistant {
Assistant assistant;

@Test
@ActivateRequestContext
public void test_simple_instruction_with_single_argument_and_no_annotations() throws IOException {
String result = assistant.chat("Tell me a joke about developers");
assertThat(result).isNotBlank();
Expand All @@ -129,6 +131,7 @@ interface SentimentAnalyzer {
SentimentAnalyzer sentimentAnalyzer;

@Test
@ActivateRequestContext
void test_extract_enum() throws IOException {
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), "POSITIVE"));

Expand Down Expand Up @@ -213,6 +216,7 @@ interface AssistantWithCalculator extends Assistant {
AssistantWithCalculator assistantWithCalculator;

@Test
@ActivateRequestContext
void should_execute_tool_then_answer() throws IOException {
var firstResponse = """
{
Expand Down Expand Up @@ -308,6 +312,7 @@ interface ChatWithSeparateMemoryForEachUser {
ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser;

@Test
@ActivateRequestContext
void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException {

ChatMemoryStore store = Arc.container().instance(ChatMemoryStore.class).get();
Expand Down
Loading

0 comments on commit a736eed

Please sign in to comment.