diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 0a56e2cee..d4939f696 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -23,6 +23,7 @@ import java.util.stream.Collectors; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Instance; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTarget; @@ -156,7 +157,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, needChatModelBean = true; } - DotName chatMemoryProviderSupplierClassDotName = null; + DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER; AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier"); if (chatMemoryProviderSupplierValue != null) { chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name(); @@ -182,7 +183,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, .collect(Collectors.toList()); } - DotName retrieverSupplierClassDotName = null; + DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER; AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier"); if (retrieverSupplierValue != null) { retrieverSupplierClassDotName = retrieverSupplierValue.asClass().name(); @@ -276,12 +277,24 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if (Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) { configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER)); needsChatMemoryProviderBean = true; + 
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString() + .equals(chatMemoryProviderSupplierClassName)) { + configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class), + new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null)); + needsChatMemoryProviderBean = true; } if (Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER.toString().equals(retrieverSupplierClassName)) { configurator.addInjectionPoint(ParameterizedType.create(Langchain4jDotNames.RETRIEVER, new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null)); needsRetrieverBean = true; + } else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString() + .equals(retrieverSupplierClassName)) { + configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class), + new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER, + new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) }, + null)); + needsRetrieverBean = true; } syntheticBeanProducer.produce(configurator.done()); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java index 2f41309d2..dad8425dc 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java @@ -45,6 +45,8 @@ public class Langchain4jDotNames { static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple( RegisterAiService.BeanChatMemoryProviderSupplier.class); + static final DotName BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple( + RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class); static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple( 
RegisterAiService.NoChatMemoryProviderSupplier.class); @@ -54,4 +56,7 @@ public class Langchain4jDotNames { static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple( RegisterAiService.BeanRetrieverSupplier.class); + static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple( + RegisterAiService.BeanIfExistsRetrieverSupplier.class); + } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 020119ee9..6266ba6a3 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -13,7 +13,6 @@ import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.retriever.Retriever; import dev.langchain4j.service.AiServices; @@ -21,8 +20,8 @@ * Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by * using the class as a CDI bean. * Under the hood Langchain4j's {@link AiServices#builder(Class)} is called - * while also providing the builder with the proper {@link ChatLanguageModel}, {@code tools} beans, - * {@link ChatMemory} or {@link ChatMemoryProvider}, {@link ModerationModel} and {@link Retriever}. + * while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional), + * {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist). *

* NOTE: The resulting CDI bean is {@link ApplicationScoped}. */ @@ -33,7 +32,8 @@ /** * Configures the way to obtain the {@link ChatLanguageModel} to use. * If not configured, the default CDI bean implementing the model is looked up. - * Such a bean provided automatically by extensions such {@code quarkus-langchain4j-openai} and + * Such a bean is provided automatically by extensions such as {@code quarkus-langchain4j-openai}, + * {@code quarkus-langchain4j-azure-openai} or * {@code quarkus-langchain4j-hugging-face} */ Class> chatLanguageModelSupplier() default BeanChatLanguageModelSupplier.class; @@ -47,14 +47,15 @@ /** * Configures the way to obtain the {@link ChatMemoryProvider} to use. - * By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}. + * By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}, but will fall back to not using + * any memory if no such bean exists. * If an arbitrary {@link ChatMemoryProvider} instance is needed, a custom implementation of * {@link Supplier} needs to be provided. *

* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then * set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier} */ - Class> chatMemoryProviderSupplier() default NoChatMemoryProviderSupplier.class; + Class> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class; /** * Configures the way to obtain the {@link Retriever} to use (when using RAG). @@ -79,7 +80,8 @@ public ChatLanguageModel get() { } /** - * Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean + * Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that the user has configured as a CDI bean. If the bean does + * not exist, Quarkus will fail at build time. */ final class BeanChatMemoryProviderSupplier implements Supplier { @@ -89,6 +91,18 @@ public ChatMemoryProvider get() { } } + /** + * Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that the user has configured as a CDI bean. + * If no such bean exists, then no memory will be used. + */ + final class BeanIfExistsChatMemoryProviderSupplier implements Supplier { + + @Override + public ChatMemoryProvider get() { + throw new UnsupportedOperationException("should never be called"); + } + } + /** * Marker class to indicate that no chat memory should be used */ @@ -111,6 +125,18 @@ public Retriever get() { } } + /** + * Marker that is used to tell Quarkus to use the {@link Retriever} that the user has configured as a CDI bean. + * If no such bean exists, then no retriever will be used.
+ */ + final class BeanIfExistsRetrieverSupplier implements Supplier> { + + @Override + public Retriever get() { + throw new UnsupportedOperationException("should never be called"); + } + } + /** * Marker class to indicate that no retriever should be used */ diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index a05ba1cb6..f38b4ebf1 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -10,8 +10,10 @@ import java.util.function.Function; import java.util.function.Supplier; +import jakarta.enterprise.inject.Instance; import jakarta.enterprise.util.TypeLiteral; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.retriever.Retriever; @@ -26,6 +28,11 @@ @Recorder public class AiServicesRecorder { + private static final TypeLiteral> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() { + }; + private static final TypeLiteral>> RETRIEVER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() { + }; + // the key is the interface's class name private static final Map metadata = new HashMap<>(); @@ -93,6 +100,13 @@ public T apply(SyntheticCreationalContext creationalContext) { .equals(info.getChatMemoryProviderSupplierClassName())) { quarkusAiServices.chatMemoryProvider(creationalContext.getInjectedReference( ChatMemoryProvider.class)); + } else if (RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class.getName() + .equals(info.getChatMemoryProviderSupplierClassName())) { + Instance instance = creationalContext + .getInjectedReference(CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL); + if (instance.isResolvable()) { + quarkusAiServices.chatMemoryProvider(instance.get()); 
+ } } else { Supplier supplier = (Supplier) Thread .currentThread().getContextClassLoader() @@ -107,6 +121,13 @@ public T apply(SyntheticCreationalContext creationalContext) { .equals(info.getRetrieverSupplierClassName())) { quarkusAiServices.retriever(creationalContext.getInjectedReference(new TypeLiteral<>() { })); + } else if (RegisterAiService.BeanIfExistsRetrieverSupplier.class.getName() + .equals(info.getRetrieverSupplierClassName())) { + Instance> instance = creationalContext + .getInjectedReference(RETRIEVER_INSTANCE_TYPE_LITERAL); + if (instance.isResolvable()) { + quarkusAiServices.retriever(instance.get()); + } } else { @SuppressWarnings("rawtypes") Supplier supplier = (Supplier) Thread diff --git a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MyAiService.java b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MyAiService.java index 05b50587a..bb5287fdb 100644 --- a/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MyAiService.java +++ b/docs/modules/ROOT/examples/io/quarkiverse/langchain4j/samples/MyAiService.java @@ -5,14 +5,13 @@ import io.quarkiverse.langchain4j.RegisterAiService; @RegisterAiService( // <1> - chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, // <2> - tools = EmailService.class // <3> + tools = EmailService.class // <2> ) public interface MyAiService { - @SystemMessage("You are a professional poet") // <4> + @SystemMessage("You are a professional poet") // <3> @UserMessage(""" - Write a poem about {topic}. The poem should be {lines} lines long. Then send this poem by email. // <5> + Write a poem about {topic}. The poem should be {lines} lines long. Then send this poem by email. 
// <4> """) - String writeAPoem(String topic, int lines); // <6> -} \ No newline at end of file + String writeAPoem(String topic, int lines); // <5> +} diff --git a/docs/modules/ROOT/pages/agent-and-tools.adoc b/docs/modules/ROOT/pages/agent-and-tools.adoc index bcd1e8211..dd3f9b4d3 100644 --- a/docs/modules/ROOT/pages/agent-and-tools.adoc +++ b/docs/modules/ROOT/pages/agent-and-tools.adoc @@ -170,9 +170,7 @@ public class AssistantWithToolsResource { } } - @RegisterAiService( - tools = Calculator.class, - chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class) // <3> + @RegisterAiService(tools = Calculator.class) // <3> public interface Assistant { String chat(String userMessage); @@ -181,7 +179,7 @@ public class AssistantWithToolsResource { ---- <1> Declare a CDI bean that provides three different tools <2> Declare a CDI bean for providing a simple in-memory message store -<3> Register an AiService that responds to a user's request and has access to the calculator tools, while also being able to keep track of the session's messages using the CDI message store declared above +<3> Register an AiService that responds to a user's request and has access to the calculator tools. This service is also able to keep track of the session's messages using the CDI message store declared above. <4> Declare an HTTP endpoint that retrieves the user's question via a query parameter and simply responds with chatbot's response Now, if we ask the chatbot `What is the square root of the sum of the numbers of letters in the words "hello" and "world"` via: diff --git a/docs/modules/ROOT/pages/ai-services.adoc b/docs/modules/ROOT/pages/ai-services.adoc index 8e0d12bd1..77125a6de 100644 --- a/docs/modules/ROOT/pages/ai-services.adoc +++ b/docs/modules/ROOT/pages/ai-services.adoc @@ -168,53 +168,48 @@ public class MyChatModelSupplier implements Supplier { As LLMs are stateless, the memory — comprising the interaction context — must be exchanged each time. 
To prevent storing excessive messages, it's crucial to evict older messages. -The `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the memory provider: +The `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the memory provider. The default value of this annotation is `RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class` +which means that the `AiService` will use whatever `ChatMemoryProvider` bean is configured by the application, while falling back to no memory if no such bean exists. +An example of such a bean is: [source,java] ---- -@RegisterAiService( - chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class) +include::{examples-dir}/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java[] ---- -It can be a class implementing `Supplier`, such as: +Notice that the messages are deleted when the scope terminates (as it will call the `close` method). + +NOTE: It is recommended that the bean use the `@RequestScoped` scope or a scope not shared between users. + +Users can provide their own custom `ChatMemoryProvider` for use in the AiService by implementing `Supplier`, such as: [source,java] ---- include::{examples-dir}/io/quarkiverse/langchain4j/samples/MySmallMemoryProvider.java[] ---- -In cases involving multiple users, ensure each user has a unique memory ID and pass this ID to the AI method: +and configuring the AiService as so: [source,java] ---- -String chat(@MemoryId int memoryId, @UserMessage String userMessage); +@RegisterAiService( + chatMemoryProviderSupplier = MySmallMemoryProvider.class) ---- -Also, remember to clear out users to prevent memory issues. - TIP: For non-memory-reliant LLM interactions, you may skip memory configuration. 
-Alternatively, you can use the `BeanChatMemoryProviderSupplier` class to use a CDI bean as memory provider: - -[source,java] ----- -include::{examples-dir}/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java[] ----- +IMPORTANT: When using tools, you need a memory of at least 3 messages to cover the tools interaction. -Notice that the messages are deleted when the scope terminates (as it will call the `close` method). +=== @MemoryId -This bean is then referenced in the `@RegisterAiService` annotation using the `RegisterAiService.BeanChatMemoryProviderSupplier.class` value: +In cases involving multiple users, ensure each user has a unique memory ID and pass this ID to the AI method: [source,java] ---- -@RegisterAiService( - chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class -) +String chat(@MemoryId int memoryId, @UserMessage String userMessage); ---- -NOTE: It is recommended that the bean use the `@RequestScoped` scope or a scope not shared between users. - -IMPORTANT: When using tools, you need a memory of at least 3 messages to cover the tools interaction. +Also, remember to clear out users to prevent memory issues. == Configuring Tools @@ -237,9 +232,7 @@ The `@Tool` annotation can provide a description of the action, aiding the LLM i [source,java] ---- -@RegisterAiService( - chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, - tools = {TransactionRepository.class, CustomerRepository.class }) +@RegisterAiService(tools = {TransactionRepository.class, CustomerRepository.class }) ---- IMPORTANT: Ensure you configure the memory provider when using tools. 
diff --git a/docs/modules/ROOT/pages/index.adoc b/docs/modules/ROOT/pages/index.adoc index 82d2981d1..e5ee71a4a 100644 --- a/docs/modules/ROOT/pages/index.adoc +++ b/docs/modules/ROOT/pages/index.adoc @@ -59,12 +59,11 @@ Once you've added the dependency and configuration, the next step involves creat include::{examples-dir}/io/quarkiverse/langchain4j/samples/MyAiService.java[] ---- <1> The `@RegisterAiService` annotation registers the _AI service_. -<2> The `chatMemoryProviderSupplier` attribute specifies the _chat memory_ provider, managing how the LLM retains conversation history (the "context"). -<3> The `tools` attribute defines the _tools_ the LLM can employ. +<2> The `tools` attribute defines the _tools_ the LLM can employ. During interaction, the LLM can invoke these tools and reflect on their output. -<4> The `@SystemMessage` annotation registers a _system message_, setting the initial context or "scope". -<5> The `@UserMessage` annotation serves as the _prompt_. -<6> The method invokes the LLM, initiating an exchange between the LLM and the application, beginning with the system message and then the user message. Your application triggers this method and receives the response. +<3> The `@SystemMessage` annotation registers a _system message_, setting the initial context or "scope". +<4> The `@UserMessage` annotation serves as the _prompt_. +<5> The method invokes the LLM, initiating an exchange between the LLM and the application, beginning with the system message and then the user message. Your application triggers this method and receives the response. 
== Advantages over vanilla Langchain4j diff --git a/docs/modules/ROOT/pages/retrievers.adoc b/docs/modules/ROOT/pages/retrievers.adoc index f724b851e..5cd80ce02 100644 --- a/docs/modules/ROOT/pages/retrievers.adoc +++ b/docs/modules/ROOT/pages/retrievers.adoc @@ -66,15 +66,7 @@ Configure the maximum number of documents to retrieve (e.g., 20 in the example) Make sure that the number of documents is not too high (or document too large). More document you have, more data you are adding to the LLM context, and you may exceed the limit. -To use the retriever in your AI service, configure it as a CDI bean: - -[source,java] ----- -@RegisterAiService(retrieverSupplier = RegisterAiService.BeanRetrieverSupplier.class) -public interface MyAiService { -// ... -} ----- +A retriever is used by default in your AI service, simply by having said retriever configured as a CDI bean. Alternatively, implement a class implementing `Supplier<Retriever<TextSegment>>` if you prefer not to expose the retriever as a CDI bean. -Then, configure the `retrieverSupplier` attribute to point to your implementation. \ No newline at end of file +Then, configure the `retrieverSupplier` attribute to point to your implementation.
diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java new file mode 100644 index 000000000..f4586cffb --- /dev/null +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/BeanDeclarativeAiServicesTest.java @@ -0,0 +1,251 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static dev.langchain4j.data.message.ChatMessageDeserializer.messagesFromJson; +import static dev.langchain4j.data.message.ChatMessageSerializer.messagesToJson; +import static dev.langchain4j.data.message.ChatMessageType.AI; +import static dev.langchain4j.data.message.ChatMessageType.USER; +import static org.acme.examples.aiservices.MessageAssertUtils.assertMultipleRequestMessage; +import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.tuple; +import static org.assertj.core.api.InstanceOfAssertFactories.list; +import static org.assertj.core.api.InstanceOfAssertFactories.map; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.assertj.core.api.InstanceOfAssertFactory; +import org.assertj.core.api.ListAssert; +import org.assertj.core.api.MapAssert; +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; 
+import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import dev.langchain4j.store.memory.chat.ChatMemoryStore; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.openai.test.WiremockUtils; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; + +public class BeanDeclarativeAiServicesTest { + + private static final int WIREMOCK_PORT = 8089; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + private static final InstanceOfAssertFactory> MAP_STRING_STRING = map(String.class, + String.class); + private static final InstanceOfAssertFactory> LIST_MAP = list(Map.class); + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + private static MessageWindowChatMemory createChatMemory() { + return MessageWindowChatMemory.withMaxMessages(10); + } + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + 
wireMockServer.resetAll(); + wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub()); + } + + public static class ChatMemoryProviderProducer { + + @Singleton + ChatMemoryProvider chatMemory(ChatMemoryStore store) { + return memoryId -> MessageWindowChatMemory.builder() + .id(memoryId) + .maxMessages(10) + .chatMemoryStore(store) + .build(); + } + } + + @Singleton + public static class CustomChatMemoryStore implements ChatMemoryStore { + + // emulating persistent storage + private final Map persistentStorage = new HashMap<>(); + + @Override + public List getMessages(Object memoryId) { + return messagesFromJson(persistentStorage.get(memoryId)); + } + + @Override + public void updateMessages(Object memoryId, List messages) { + persistentStorage.put(memoryId, messagesToJson(messages)); + } + + @Override + public void deleteMessages(Object memoryId) { + persistentStorage.remove(memoryId); + } + } + + @RegisterAiService + interface ChatWithSeparateMemoryForEachUser { + + String chat(@MemoryId int memoryId, @UserMessage String userMessage); + } + + @Inject + ChatWithSeparateMemoryForEachUser chatWithSeparateMemoryForEachUser; + + @Test + void should_keep_separate_chat_memory_for_each_user_in_store() throws IOException { + + ChatMemoryStore store = Arc.container().instance(ChatMemoryStore.class).get(); + + int firstMemoryId = 1; + int secondMemoryId = 2; + + /* **** First request for user 1 **** */ + String firstMessageFromFirstUser = "Hello, my name is Klaus"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Nice to meet you Klaus")); + String firstAiResponseToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, firstMessageFromFirstUser); + + // assert response + assertThat(firstAiResponseToFirstUser).isEqualTo("Nice to meet you Klaus"); + + // assert request + assertSingleRequestMessage(getRequestAsMap(), firstMessageFromFirstUser); + + // assert chat memory + 
assertThat(store.getMessages(firstMemoryId)).hasSize(2) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser)); + + /* **** First request for user 2 **** */ + wireMockServer.resetRequests(); + + String firstMessageFromSecondUser = "Hello, my name is Francine"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Nice to meet you Francine")); + String firstAiResponseToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId, firstMessageFromSecondUser); + + // assert response + assertThat(firstAiResponseToSecondUser).isEqualTo("Nice to meet you Francine"); + + // assert request + assertSingleRequestMessage(getRequestAsMap(), firstMessageFromSecondUser); + + // assert chat memory + assertThat(store.getMessages(secondMemoryId)).hasSize(2) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser)); + + /* **** Second request for user 1 **** */ + wireMockServer.resetRequests(); + + String secondsMessageFromFirstUser = "What is my name?"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Your name is Klaus")); + String secondAiMessageToFirstUser = chatWithSeparateMemoryForEachUser.chat(firstMemoryId, secondsMessageFromFirstUser); + + // assert response + assertThat(secondAiMessageToFirstUser).contains("Klaus"); + + // assert request + assertMultipleRequestMessage(getRequestAsMap(), + List.of( + new MessageAssertUtils.MessageContent("user", firstMessageFromFirstUser), + new MessageAssertUtils.MessageContent("assistant", firstAiResponseToFirstUser), + new MessageAssertUtils.MessageContent("user", secondsMessageFromFirstUser))); + + // assert chat memory + assertThat(store.getMessages(firstMemoryId)).hasSize(4) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, 
firstMessageFromFirstUser), tuple(AI, firstAiResponseToFirstUser), + tuple(USER, secondsMessageFromFirstUser), tuple(AI, secondAiMessageToFirstUser)); + + /* **** Second request for user 2 **** */ + wireMockServer.resetRequests(); + + String secondsMessageFromSecondUser = "What is my name?"; + wireMockServer.stubFor(WiremockUtils.chatCompletionsMessageContent(Optional.empty(), + "Your name is Francine")); + String secondAiMessageToSecondUser = chatWithSeparateMemoryForEachUser.chat(secondMemoryId, + secondsMessageFromSecondUser); + + // assert response + assertThat(secondAiMessageToSecondUser).contains("Francine"); + + // assert request + assertMultipleRequestMessage(getRequestAsMap(), + List.of( + new MessageAssertUtils.MessageContent("user", firstMessageFromSecondUser), + new MessageAssertUtils.MessageContent("assistant", firstAiResponseToSecondUser), + new MessageAssertUtils.MessageContent("user", secondsMessageFromSecondUser))); + + // assert chat memory + assertThat(store.getMessages(secondMemoryId)).hasSize(4) + .extracting(ChatMessage::type, ChatMessage::text) + .containsExactly(tuple(USER, firstMessageFromSecondUser), tuple(AI, firstAiResponseToSecondUser), + tuple(USER, secondsMessageFromSecondUser), tuple(AI, secondAiMessageToSecondUser)); + } + + private Map getRequestAsMap() throws IOException { + return getRequestAsMap(getRequestBody()); + } + + private Map getRequestAsMap(byte[] body) throws IOException { + return mapper.readValue(body, MAP_TYPE_REF); + } + + private byte[] getRequestBody() { + assertThat(wireMockServer.getAllServeEvents()).hasSize(1); + ServeEvent serveEvent = wireMockServer.getAllServeEvents().get(0); // this works because we reset requests for Wiremock before each test + return getRequestBody(serveEvent); + } + + private byte[] getRequestBody(ServeEvent serveEvent) { + LoggedRequest request = serveEvent.getRequest(); + assertThat(request.getBody()).isNotEmpty(); + return request.getBody(); + } +} diff --git 
a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java index a63bd0a75..d5d2dd611 100644 --- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java @@ -95,7 +95,7 @@ void setup() { wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub()); } - @RegisterAiService + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) interface Assistant { String chat(String message); @@ -118,7 +118,7 @@ enum Sentiment { NEGATIVE } - @RegisterAiService + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) interface SentimentAnalyzer { @UserMessage("Analyze sentiment of {it}") diff --git a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/Bot.java b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/Bot.java index 27bfb657d..bc7ddd141 100644 --- a/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/Bot.java +++ b/samples/chatbot/src/main/java/io/quarkiverse/langchain4j/sample/chatbot/Bot.java @@ -5,7 +5,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; -@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, retrieverSupplier = RegisterAiService.BeanRetrieverSupplier.class) +@RegisterAiService public interface Bot { @SystemMessage(""" diff --git a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java index 9e0d9bdec..06bace675 100644 --- 
a/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java +++ b/samples/email-a-poem/src/main/java/io/quarkiverse/langchain4j/sample/MyAiService.java @@ -4,7 +4,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; -@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, tools = EmailService.class) +@RegisterAiService(tools = EmailService.class) public interface MyAiService { /** diff --git a/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java b/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java index 6e29fccec..b3b53b6f1 100644 --- a/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java +++ b/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java @@ -8,8 +8,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; -@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, tools = { - TransactionRepository.class, CustomerRepository.class }) +@RegisterAiService(tools = { TransactionRepository.class, CustomerRepository.class }) public interface FraudDetectionAi { @SystemMessage(""" diff --git a/samples/review-triage/src/main/resources/META-INF/resources/index.html b/samples/review-triage/src/main/resources/META-INF/resources/index.html index 6590caf45..00931d106 100644 --- a/samples/review-triage/src/main/resources/META-INF/resources/index.html +++ b/samples/review-triage/src/main/resources/META-INF/resources/index.html @@ -11,7 +11,7 @@ @@ -61,4 +61,4 @@ - \ No newline at end of file +