Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the need to explicitly configure retriever and memory beans for @RegisterAiService #51

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
Expand Down Expand Up @@ -156,7 +157,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
needChatModelBean = true;
}

DotName chatMemoryProviderSupplierClassDotName = null;
DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER;
AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
if (chatMemoryProviderSupplierValue != null) {
chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
Expand All @@ -182,7 +183,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
.collect(Collectors.toList());
}

DotName retrieverSupplierClassDotName = null;
DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER;
AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier");
if (retrieverSupplierValue != null) {
retrieverSupplierClassDotName = retrieverSupplierValue.asClass().name();
Expand Down Expand Up @@ -276,12 +277,24 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER));
needsChatMemoryProviderBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
.equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
needsChatMemoryProviderBean = true;
}

if (Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER.toString().equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null));
needsRetrieverBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString()
.equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) },
null));
needsRetrieverBean = true;
}

syntheticBeanProducer.produce(configurator.done());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class Langchain4jDotNames {

static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatMemoryProviderSupplier.class);
static final DotName BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class);
static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.NoChatMemoryProviderSupplier.class);

Expand All @@ -54,4 +56,7 @@ public class Langchain4jDotNames {
static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsRetrieverSupplier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;

/**
* Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by
* using the class as a CDI bean.
* Under the hood Langchain4j's {@link AiServices#builder(Class)} is called
* while also providing the builder with the proper {@link ChatLanguageModel}, {@code tools} beans,
* {@link ChatMemory} or {@link ChatMemoryProvider}, {@link ModerationModel} and {@link Retriever}.
* while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional),
* {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist).
* <p>
* NOTE: The resulting CDI bean is {@link ApplicationScoped}.
*/
Expand All @@ -33,7 +32,8 @@
/**
* Configures the way to obtain the {@link ChatLanguageModel} to use.
* If not configured, the default CDI bean implementing the model is looked up.
* Such a bean provided automatically by extensions such {@code quarkus-langchain4j-openai} and
* Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai},
* {@code quarkus-langchain4j-azure-openai} or
* {@code quarkus-langchain4j-hugging-face}
*/
Class<? extends Supplier<ChatLanguageModel>> chatLanguageModelSupplier() default BeanChatLanguageModelSupplier.class;
Expand All @@ -47,14 +47,15 @@

/**
* Configures the way to obtain the {@link ChatMemoryProvider} to use.
* By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}.
* By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link ChatMemoryProvider} instance is needed, a custom implementation of
* {@link Supplier<ChatMemoryProvider>} needs to be provided.
* <p>
* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then
* set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier}
*/
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default NoChatMemoryProviderSupplier.class;
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class;

/**
* Configures the way to obtain the {@link Retriever} to use (when using RAG).
Expand All @@ -79,7 +80,8 @@ public ChatLanguageModel get() {
}

/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. If the bean does
* not exist, Quarkus will fail at build time.
*/
final class BeanChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

Expand All @@ -89,6 +91,18 @@ public ChatMemoryProvider get() {
}
}

/**
* Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that the user has configured as a CDI bean.
* If no such bean exists, then no memory will be used.
*/
final class BeanIfExistsChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

@Override
public ChatMemoryProvider get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no chat memory should be used
*/
Expand All @@ -111,6 +125,18 @@ public Retriever<TextSegment> get() {
}
}

/**
* Marker that is used to tell Quarkus to use the {@link Retriever} that the user has configured as a CDI bean.
* If no such bean exists, then no retriever will be used.
*/
final class BeanIfExistsRetrieverSupplier implements Supplier<Retriever<TextSegment>> {

@Override
public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no retriever should be used
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import java.util.function.Function;
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.util.TypeLiteral;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
Expand All @@ -26,6 +28,11 @@
@Recorder
public class AiServicesRecorder {

private static final TypeLiteral<Instance<ChatMemoryProvider>> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {
};
private static final TypeLiteral<Instance<Retriever<TextSegment>>> RETRIEVER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {
};

// the key is the interface's class name
private static final Map<String, AiServiceClassCreateInfo> metadata = new HashMap<>();

Expand Down Expand Up @@ -93,6 +100,13 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
.equals(info.getChatMemoryProviderSupplierClassName())) {
quarkusAiServices.chatMemoryProvider(creationalContext.getInjectedReference(
ChatMemoryProvider.class));
} else if (RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class.getName()
.equals(info.getChatMemoryProviderSupplierClassName())) {
Instance<ChatMemoryProvider> instance = creationalContext
.getInjectedReference(CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.chatMemoryProvider(instance.get());
}
} else {
Supplier<? extends ChatMemoryProvider> supplier = (Supplier<? extends ChatMemoryProvider>) Thread
.currentThread().getContextClassLoader()
Expand All @@ -107,6 +121,13 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
.equals(info.getRetrieverSupplierClassName())) {
quarkusAiServices.retriever(creationalContext.getInjectedReference(new TypeLiteral<>() {
}));
} else if (RegisterAiService.BeanIfExistsRetrieverSupplier.class.getName()
.equals(info.getRetrieverSupplierClassName())) {
Instance<Retriever<TextSegment>> instance = creationalContext
.getInjectedReference(RETRIEVER_INSTANCE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.retriever(instance.get());
}
} else {
@SuppressWarnings("rawtypes")
Supplier<? extends Retriever> supplier = (Supplier<? extends Retriever>) Thread
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService( // <1>
chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class, // <2>
tools = EmailService.class // <3>
tools = EmailService.class // <2>
)
public interface MyAiService {

@SystemMessage("You are a professional poet") // <4>
@SystemMessage("You are a professional poet") // <3>
@UserMessage("""
Write a poem about {topic}. The poem should be {lines} lines long. Then send this poem by email. // <5>
Write a poem about {topic}. The poem should be {lines} lines long. Then send this poem by email. // <4>
""")
String writeAPoem(String topic, int lines); // <6>
}
String writeAPoem(String topic, int lines); // <5>
}
6 changes: 2 additions & 4 deletions docs/modules/ROOT/pages/agent-and-tools.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,7 @@ public class AssistantWithToolsResource {
}
}

@RegisterAiService(
tools = Calculator.class,
chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class) // <3>
@RegisterAiService(tools = Calculator.class) // <3>
public interface Assistant {

String chat(String userMessage);
Expand All @@ -181,7 +179,7 @@ public class AssistantWithToolsResource {
----
<1> Declare a CDI bean that provides three different tools
<2> Declare a CDI bean for providing a simple in-memory message store
<3> Register an AiService that responds to a user's request and has access to the calculator tools, while also being able to keep track of the session's messages using the CDI message store declared above
<3> Register an AiService that responds to a user's request and has access to the calculator tools. This service is also able to keep track of the session's messages using the CDI message store declared above.
<4> Declare an HTTP endpoint that retrieves the user's question via a query parameter and simply responds with chatbot's response

Now, if we ask the chatbot `What is the square root of the sum of the numbers of letters in the words "hello" and "world"` via:
Expand Down
43 changes: 18 additions & 25 deletions docs/modules/ROOT/pages/ai-services.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -168,53 +168,48 @@ public class MyChatModelSupplier implements Supplier<ChatLanguageModel> {

As LLMs are stateless, the memory — comprising the interaction context — must be exchanged each time. To prevent storing excessive messages, it's crucial to evict older messages.

The `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the memory provider:
The `chatMemoryProviderSupplier` attribute of the `@RegisterAiService` annotation enables configuring the memory provider. The default value of this annotation is `RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class`
which means that the `AiService` will use whatever `ChatMemoryProvider` bean is configured by the application, while falling back to no memory if no such bean exists.
An example of such a bean is:

[source,java]
----
@RegisterAiService(
chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class)
include::{examples-dir}/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java[]
----

It can be a class implementing `Supplier<ChatMemoryProvider>`, such as:
Notice that the messages are deleted when the scope terminates (as it will call the `close` method).

NOTE: It is recommended that the bean use the `@RequestScoped` scope or a scope not shared between users.

Users can provide their own custom `ChatMemoryProvider` for use in the AiService by implementing `Supplier<ChatMemoryProvider>`, such as:

[source,java]
----
include::{examples-dir}/io/quarkiverse/langchain4j/samples/MySmallMemoryProvider.java[]
----

In cases involving multiple users, ensure each user has a unique memory ID and pass this ID to the AI method:
and configuring the AiService as so:

[source,java]
----
String chat(@MemoryId int memoryId, @UserMessage String userMessage);
@RegisterAiService(
chatMemoryProviderSupplier = MySmallMemoryProvider.class)
----

Also, remember to clear out users to prevent memory issues.

TIP: For non-memory-reliant LLM interactions, you may skip memory configuration.

Alternatively, you can use the `BeanChatMemoryProviderSupplier` class to use a CDI bean as memory provider:

[source,java]
----
include::{examples-dir}/io/quarkiverse/langchain4j/samples/ChatMemoryBean.java[]
----
IMPORTANT: When using tools, you need a memory of at least 3 messages to cover the tools interaction.

Notice that the messages are deleted when the scope terminates (as it will call the `close` method).
=== @MemoryId

This bean is then referenced in the `@RegisterAiService` annotation using the `RegisterAiService.BeanChatMemoryProviderSupplier.class` value:
In cases involving multiple users, ensure each user has a unique memory ID and pass this ID to the AI method:

[source,java]
----
@RegisterAiService(
chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class
)
String chat(@MemoryId int memoryId, @UserMessage String userMessage);
----

NOTE: It is recommended that the bean use the `@RequestScoped` scope or a scope not shared between users.

IMPORTANT: When using tools, you need a memory of at least 3 messages to cover the tools interaction.
Also, remember to clear out users to prevent memory issues.

== Configuring Tools

Expand All @@ -237,9 +232,7 @@ The `@Tool` annotation can provide a description of the action, aiding the LLM i

[source,java]
----
@RegisterAiService(
chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class,
tools = {TransactionRepository.class, CustomerRepository.class })
@RegisterAiService(tools = {TransactionRepository.class, CustomerRepository.class })
----

IMPORTANT: Ensure you configure the memory provider when using tools.
Expand Down
9 changes: 4 additions & 5 deletions docs/modules/ROOT/pages/index.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,11 @@ Once you've added the dependency and configuration, the next step involves creat
include::{examples-dir}/io/quarkiverse/langchain4j/samples/MyAiService.java[]
----
<1> The `@RegisterAiService` annotation registers the _AI service_.
<2> The `chatMemoryProviderSupplier` attribute specifies the _chat memory_ provider, managing how the LLM retains conversation history (the "context").
<3> The `tools` attribute defines the _tools_ the LLM can employ.
<2> The `tools` attribute defines the _tools_ the LLM can employ.
During interaction, the LLM can invoke these tools and reflect on their output.
<4> The `@SystemMessage` annotation registers a _system message_, setting the initial context or "scope".
<5> The `@UserMessage` annotation serves as the _prompt_.
<6> The method invokes the LLM, initiating an exchange between the LLM and the application, beginning with the system message and then the user message. Your application triggers this method and receives the response.
<3> The `@SystemMessage` annotation registers a _system message_, setting the initial context or "scope".
<4> The `@UserMessage` annotation serves as the _prompt_.
<5> The method invokes the LLM, initiating an exchange between the LLM and the application, beginning with the system message and then the user message. Your application triggers this method and receives the response.

== Advantages over vanilla Langchain4j

Expand Down
12 changes: 2 additions & 10 deletions docs/modules/ROOT/pages/retrievers.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,7 @@ Configure the maximum number of documents to retrieve (e.g., 20 in the example)
Make sure that the number of documents is not too high (or document too large).
More document you have, more data you are adding to the LLM context, and you may exceed the limit.

To use the retriever in your AI service, configure it as a CDI bean:

[source,java]
----
@RegisterAiService(retrieverSupplier = RegisterAiService.BeanRetrieverSupplier.class)
public interface MyAiService {
// ...
}
----
A retriever is used by default in your AI service, simply by having said retriever configured as a CDI bean,

Alternatively, implement a class implementing `Supplier<Retriever<TextSegment>>` if you prefer not to expose the retriever as a CDI bean.
Then, configure the `retrieverSupplier` attribute to point to your implementation.
Then, configure the `retrieverSupplier` attribute to point to your implementation.
Loading