Skip to content

Commit

Permalink
Remove the need to explicitly configure retriever and memory beans fo…
Browse files Browse the repository at this point in the history
…r @RegisterAiService

Closes: #43
  • Loading branch information
geoand committed Nov 22, 2023
1 parent 02d4244 commit 8279813
Show file tree
Hide file tree
Showing 10 changed files with 332 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.stream.Collectors;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
Expand Down Expand Up @@ -156,7 +157,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
needChatModelBean = true;
}

DotName chatMemoryProviderSupplierClassDotName = null;
DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER;
AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
if (chatMemoryProviderSupplierValue != null) {
chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
Expand All @@ -182,7 +183,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
.collect(Collectors.toList());
}

DotName retrieverSupplierClassDotName = null;
DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER;
AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier");
if (retrieverSupplierValue != null) {
retrieverSupplierClassDotName = retrieverSupplierValue.asClass().name();
Expand Down Expand Up @@ -276,12 +277,24 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER.toString().equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER));
needsChatMemoryProviderBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
.equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
needsChatMemoryProviderBean = true;
}

if (Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER.toString().equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null));
needsRetrieverBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString()
.equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) },
null));
needsRetrieverBean = true;
}

syntheticBeanProducer.produce(configurator.done());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class Langchain4jDotNames {

static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatMemoryProviderSupplier.class);
static final DotName BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class);
static final DotName NO_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
RegisterAiService.NoChatMemoryProviderSupplier.class);

Expand All @@ -54,4 +56,7 @@ public class Langchain4jDotNames {
static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsRetrieverSupplier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;

/**
* Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by
* using the class as a CDI bean.
* Under the hood Langchain4j's {@link AiServices#builder(Class)} is called
* while also providing the builder with the proper {@link ChatLanguageModel}, {@code tools} beans,
* {@link ChatMemory} or {@link ChatMemoryProvider}, {@link ModerationModel} and {@link Retriever}.
* while also providing the builder with the proper {@link ChatLanguageModel} bean (mandatory), {@code tools} bean (optional),
* {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist).
* <p>
* NOTE: The resulting CDI bean is {@link ApplicationScoped}.
*/
Expand All @@ -33,7 +32,8 @@
/**
* Configures the way to obtain the {@link ChatLanguageModel} to use.
* If not configured, the default CDI bean implementing the model is looked up.
* Such a bean provided automatically by extensions such {@code quarkus-langchain4j-openai} and
* Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai},
* {@code quarkus-langchain4j-azure-openai} or
* {@code quarkus-langchain4j-hugging-face}
*/
Class<? extends Supplier<ChatLanguageModel>> chatLanguageModelSupplier() default BeanChatLanguageModelSupplier.class;
Expand All @@ -47,14 +47,15 @@

/**
* Configures the way to obtain the {@link ChatMemoryProvider} to use.
* By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}.
* By default, Quarkus will look for a CDI bean that implements {@link ChatMemoryProvider}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link ChatMemoryProvider} instance is needed, a custom implementation of
* {@link Supplier<ChatMemoryProvider>} needs to be provided.
* <p>
* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then
* set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier}
*/
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default NoChatMemoryProviderSupplier.class;
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class;

/**
* Configures the way to obtain the {@link Retriever} to use (when using RAG).
Expand All @@ -79,7 +80,8 @@ public ChatLanguageModel get() {
}

/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean. If the bean does
* not exist, Quarkus will fail at build time.
*/
final class BeanChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

Expand All @@ -89,6 +91,18 @@ public ChatMemoryProvider get() {
}
}

/**
* Marker that is used to tell Quarkus to use the {@link ChatMemoryProvider} that the user has configured as a CDI bean.
* If no such bean exists, then no memory will be used.
*/
final class BeanIfExistsChatMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {

@Override
public ChatMemoryProvider get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no chat memory should be used
*/
Expand All @@ -111,6 +125,18 @@ public Retriever<TextSegment> get() {
}
}

/**
* Marker that is used to tell Quarkus to use the {@link Retriever} that the user has configured as a CDI bean.
* If no such bean exists, then no retriever will be used.
*/
final class BeanIfExistsRetrieverSupplier implements Supplier<Retriever<TextSegment>> {

@Override
public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no retriever should be used
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import java.util.function.Function;
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.util.TypeLiteral;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
Expand All @@ -26,6 +28,11 @@
@Recorder
public class AiServicesRecorder {

private static final TypeLiteral<Instance<ChatMemoryProvider>> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {
};
private static final TypeLiteral<Instance<Retriever<TextSegment>>> RETRIEVER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {
};

// the key is the interface's class name
private static final Map<String, AiServiceClassCreateInfo> metadata = new HashMap<>();

Expand Down Expand Up @@ -93,6 +100,13 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
.equals(info.getChatMemoryProviderSupplierClassName())) {
quarkusAiServices.chatMemoryProvider(creationalContext.getInjectedReference(
ChatMemoryProvider.class));
} else if (RegisterAiService.BeanIfExistsChatMemoryProviderSupplier.class.getName()
.equals(info.getChatMemoryProviderSupplierClassName())) {
Instance<ChatMemoryProvider> instance = creationalContext
.getInjectedReference(CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.chatMemoryProvider(instance.get());
}
} else {
Supplier<? extends ChatMemoryProvider> supplier = (Supplier<? extends ChatMemoryProvider>) Thread
.currentThread().getContextClassLoader()
Expand All @@ -107,6 +121,13 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
.equals(info.getRetrieverSupplierClassName())) {
quarkusAiServices.retriever(creationalContext.getInjectedReference(new TypeLiteral<>() {
}));
} else if (RegisterAiService.BeanIfExistsRetrieverSupplier.class.getName()
.equals(info.getRetrieverSupplierClassName())) {
Instance<Retriever<TextSegment>> instance = creationalContext
.getInjectedReference(RETRIEVER_INSTANCE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.retriever(instance.get());
}
} else {
@SuppressWarnings("rawtypes")
Supplier<? extends Retriever> supplier = (Supplier<? extends Retriever>) Thread
Expand Down
Loading

0 comments on commit 8279813

Please sign in to comment.