Skip to content

Commit

Permalink
Make @RegisterAiService beans request scoped by default
Browse files Browse the repository at this point in the history
This is done because otherwise the chat memory
does not get cleared properly.

Fixes: #95
  • Loading branch information
geoand committed Dec 5, 2023
1 parent 51fc36e commit 34386f7
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;

import org.jboss.jandex.AnnotationInstance;
Expand Down Expand Up @@ -60,6 +59,8 @@
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
Expand Down Expand Up @@ -203,13 +204,16 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo);
ScopeInfo cdiScope = declaredScope != null ? declaredScope.getInfo() : BuiltinScope.REQUEST.getInfo();

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName));
retrieverSupplierClassDotName, cdiScope));
}

if (needChatModelBean) {
Expand Down Expand Up @@ -271,7 +275,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName)))
.setRuntimeInit()
.scope(ApplicationScoped.class);
.scope(bi.getCdiScope());
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL));
needsChatModelBean = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;

import io.quarkus.arc.processor.ScopeInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
Expand All @@ -18,16 +19,19 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {

private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final ScopeInfo cdiScope;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName) {
DotName retrieverSupplierClassDotName,
ScopeInfo cdiScope) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.cdiScope = cdiScope;
}

public ClassInfo getServiceClassInfo() {
Expand All @@ -49,4 +53,8 @@ public DotName getChatMemoryProviderSupplierClassDotName() {
public DotName getRetrieverSupplierClassDotName() {
return retrieverSupplierClassDotName;
}

public ScopeInfo getCdiScope() {
return cdiScope;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import java.lang.annotation.Target;
import java.util.function.Supplier;

import jakarta.enterprise.context.ApplicationScoped;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -23,7 +21,9 @@
* while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional),
* {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist).
* <p>
* NOTE: The resulting CDI bean is {@link ApplicationScoped}.
* NOTE: The resulting CDI bean is {@link jakarta.enterprise.context.RequestScoped} be default. If you need to change the scope,
* simply annotate the class with a CDI scope.
* CAUTION: When using anything other than the request scope, you need to be very careful with the chat memory implementation.
* <p>
* NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated
* for the method invocations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public void deleteMessages(Object memoryId) {
}

@RegisterAiService
@Singleton
interface ChatWithSeparateMemoryForEachUser {

String chat(@MemoryId int memoryId, @UserMessage String userMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Map;
import java.util.Optional;

import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

Expand Down Expand Up @@ -105,6 +106,7 @@ interface Assistant {
Assistant assistant;

@Test
@ActivateRequestContext
public void test_simple_instruction_with_single_argument_and_no_annotations() throws IOException {
String result = assistant.chat("Tell me a joke about developers");
assertThat(result).isNotBlank();
Expand All @@ -129,6 +131,7 @@ interface SentimentAnalyzer {
SentimentAnalyzer sentimentAnalyzer;

@Test
@ActivateRequestContext
void test_extract_enum() throws IOException {
wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), "POSITIVE"));

Expand Down Expand Up @@ -213,6 +216,7 @@ interface AssistantWithCalculator extends Assistant {
AssistantWithCalculator assistantWithCalculator;

@Test
@ActivateRequestContext
void should_execute_tool_then_answer() throws IOException {
var firstResponse = """
{
Expand Down Expand Up @@ -308,6 +312,7 @@ interface ChatWithSeparateMemoryForEachUser {
ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser;

@Test
@ActivateRequestContext
void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException {

ChatMemoryStore store = Arc.container().instance(ChatMemoryStore.class).get();
Expand Down
6 changes: 5 additions & 1 deletion samples/chatbot/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-websockets</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-context-propagation</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
Expand Down Expand Up @@ -146,4 +150,4 @@
</profile>
</profiles>

</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jakarta.websocket.server.ServerEndpoint;

import io.smallrye.mutiny.infrastructure.Infrastructure;
import org.eclipse.microprofile.context.ManagedExecutor;

@ServerEndpoint("/chatbot")
public class ChatBotWebSocket {
Expand All @@ -17,9 +18,12 @@ public class ChatBotWebSocket {
@Inject
ChatMemoryBean chatMemoryBean;

@Inject
ManagedExecutor managedExecutor;

@OnOpen
public void onOpen(Session session) {
Infrastructure.getDefaultExecutor().execute(() -> {
managedExecutor.execute(() -> {
String response = bot.chat(session, "hello");
try {
session.getBasicRemote().sendText(response);
Expand All @@ -36,7 +40,7 @@ void onClose(Session session) {

@OnMessage
public void onMessage(String message, Session session) {
Infrastructure.getDefaultExecutor().execute(() -> {
managedExecutor.execute(() -> {
String response = bot.chat(session, message);
try {
session.getBasicRemote().sendText(response);
Expand Down
6 changes: 5 additions & 1 deletion samples/csv-chatbot/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-websockets</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-context-propagation</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
Expand Down Expand Up @@ -155,4 +159,4 @@
</profile>
</profiles>

</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jakarta.websocket.*;
import jakarta.websocket.server.ServerEndpoint;

import io.smallrye.mutiny.infrastructure.Infrastructure;
import org.eclipse.microprofile.context.ManagedExecutor;

@ServerEndpoint("/chatbot")
public class ChatBotWebSocket {
Expand All @@ -17,9 +17,12 @@ public class ChatBotWebSocket {
@Inject
ChatMemoryBean chatMemoryBean;

@Inject
ManagedExecutor managedExecutor;

@OnOpen
public void onOpen(Session session) {
Infrastructure.getDefaultExecutor().execute(() -> {
managedExecutor.execute(() -> {
String response = bot.chat(session, "hello");
try {
session.getBasicRemote().sendText(response);
Expand All @@ -36,7 +39,7 @@ void onClose(Session session) {

@OnMessage
public void onMessage(String message, Session session) {
Infrastructure.getDefaultExecutor().execute(() -> {
managedExecutor.execute(() -> {
String response = bot.chat(session, message);
try {
session.getBasicRemote().sendText(response);
Expand Down

0 comments on commit 34386f7

Please sign in to comment.