Skip to content

Commit

Permalink
Refactor retriever use in @RegisterAiService
Browse files Browse the repository at this point in the history
With this change a retriever is no longer
added by default to an AiService,
but needs to be configured explicitly.

Closes: #184
  • Loading branch information
geoand committed Jan 17, 2024
1 parent 29db5c1 commit f2398bd
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.service.ServiceOutputParser.outputFormatInstructions;
import static io.quarkiverse.langchain4j.deployment.ExceptionUtil.illegalConfigurationForMethod;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.NO_RETRIEVER;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -209,12 +210,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER;
AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier");
if (retrieverSupplierValue != null) {
retrieverSupplierClassDotName = retrieverSupplierValue.asClass().name();
if (!retrieverSupplierClassDotName.equals(Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER)) {
validateSupplierAndRegisterForReflection(retrieverSupplierClassDotName, index, reflectiveClassProducer);
DotName retrieverClassDotName = null;
AnnotationValue retrieverValue = instance.value("retriever");
if (retrieverValue != null) {
retrieverClassDotName = retrieverValue.asClass().name();
if (NO_RETRIEVER.equals(retrieverClassDotName)) {
retrieverClassDotName = null;
}
}

Expand Down Expand Up @@ -247,7 +248,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName,
retrieverClassDotName,
auditServiceSupplierClassName,
moderationModelSupplierClassName,
cdiScope));
Expand Down Expand Up @@ -306,8 +307,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getChatMemoryProviderSupplierClassDotName().toString()
: null;

String retrieverSupplierClassName = bi.getRetrieverSupplierClassDotName() != null
? bi.getRetrieverSupplierClassDotName().toString()
String retrieverClassName = bi.getRetrieverClassDotName() != null
? bi.getRetrieverClassDotName().toString()
: null;

String auditServiceClassSupplierName = bi.getAuditServiceClassSupplierDotName() != null
Expand All @@ -323,7 +324,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName,
retrieverClassName,
auditServiceClassSupplierName,
moderationModelSupplierClassName)))
.setRuntimeInit()
Expand All @@ -349,16 +350,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsChatMemoryProviderBean = true;
}

if (Langchain4jDotNames.BEAN_RETRIEVER_SUPPLIER.toString().equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null));
needsRetrieverBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString()
.equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) },
null));
if (retrieverClassName != null) {
configurator.addInjectionPoint(ClassType.create(retrieverClassName));
needsRetrieverBean = true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final List<DotName> toolDotNames;

private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName retrieverClassDotName;
private final DotName auditServiceClassSupplierDotName;
private final DotName moderationModelSupplierDotName;
private final ScopeInfo cdiScope;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName,
DotName retrieverClassDotName,
DotName auditServiceClassSupplierDotName,
DotName moderationModelSupplierDotName,
ScopeInfo cdiScope) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.retrieverClassDotName = retrieverClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
this.moderationModelSupplierDotName = moderationModelSupplierDotName;
this.cdiScope = cdiScope;
Expand All @@ -56,8 +56,8 @@ public DotName getChatMemoryProviderSupplierClassDotName() {
return chatMemoryProviderSupplierClassDotName;
}

public DotName getRetrieverSupplierClassDotName() {
return retrieverSupplierClassDotName;
public DotName getRetrieverClassDotName() {
return retrieverClassDotName;
}

public DotName getAuditServiceClassSupplierDotName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import org.jboss.jandex.DotName;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
Expand Down Expand Up @@ -51,16 +50,11 @@ public class Langchain4jDotNames {
RegisterAiService.BeanChatMemoryProviderSupplier.class);

static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class);
static final DotName NO_RETRIEVER = DotName.createSimple(
RegisterAiService.NoRetriever.class);

static final DotName AUDIT_SERVICE = DotName.createSimple(AuditService.class);

static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsAuditServiceSupplier.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import java.util.List;
import java.util.function.Supplier;

import dev.langchain4j.data.segment.TextSegment;
Expand Down Expand Up @@ -73,13 +74,10 @@
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanChatMemoryProviderSupplier.class;

/**
* Configures the way to obtain the {@link Retriever} to use (when using RAG).
* Configures the way to obtain the {@link Retriever} to use (when using RAG). All tools are expected to be CDI beans
* By default, no supplier is used.
* If a CDI bean of type {@link ChatMemory} is needed, the value should be {@link BeanRetrieverSupplier}.
* If an arbitrary {@link ChatMemory} instance is needed, a custom implementation of {@link Supplier<ChatMemory>}
* needs to be provided.
*/
Class<? extends Supplier<Retriever<TextSegment>>> retrieverSupplier() default NoRetrieverSupplier.class;
Class<? extends Retriever<TextSegment>> retriever() default NoRetriever.class;

/**
* Configures the way to obtain the {@link AuditService} to use.
Expand Down Expand Up @@ -127,36 +125,13 @@ public ChatMemoryProvider get() {
}
}

/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean
*/
final class BeanRetrieverSupplier implements Supplier<Retriever<TextSegment>> {

@Override
public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the {@link Retriever} that the user has configured as a CDI bean.
* If no such bean exists, then no retriever will be used.
*/
final class BeanIfExistsRetrieverSupplier implements Supplier<Retriever<TextSegment>> {

@Override
public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no retriever should be used
*/
final class NoRetrieverSupplier implements Supplier<Retriever<TextSegment>> {
final class NoRetriever implements Retriever<TextSegment> {

@Override
public Retriever<TextSegment> get() {
public List<TextSegment> findRelevant(String text) {
throw new UnsupportedOperationException("should never be called");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,25 +113,9 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

if (info.getRetrieverSupplierClassName() != null) {
if (RegisterAiService.BeanRetrieverSupplier.class.getName()
.equals(info.getRetrieverSupplierClassName())) {
quarkusAiServices.retriever(creationalContext.getInjectedReference(new TypeLiteral<>() {
}));
} else if (RegisterAiService.BeanIfExistsRetrieverSupplier.class.getName()
.equals(info.getRetrieverSupplierClassName())) {
Instance<Retriever<TextSegment>> instance = creationalContext
.getInjectedReference(RETRIEVER_INSTANCE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.retriever(instance.get());
}
} else {
@SuppressWarnings("rawtypes")
Supplier<? extends Retriever> supplier = (Supplier<? extends Retriever>) Thread
.currentThread().getContextClassLoader().loadClass(info.getRetrieverSupplierClassName())
.getConstructor().newInstance();
quarkusAiServices.retriever(supplier.get());
}
if (info.getRetrieverClassName() != null) {
quarkusAiServices.retriever((Retriever<TextSegment>) creationalContext.getInjectedReference(
Thread.currentThread().getContextClassLoader().loadClass(info.getRetrieverClassName())));
}

if (info.getAuditServiceClassSupplierName() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@ public class DeclarativeAiServiceCreateInfo {
private final String languageModelSupplierClassName;
private final List<String> toolsClassNames;
private final String chatMemoryProviderSupplierClassName;
private final String retrieverSupplierClassName;
private final String retrieverClassName;

private final String auditServiceClassSupplierName;
private final String moderationModelSupplierClassName;

@RecordableConstructor
public DeclarativeAiServiceCreateInfo(String serviceClassName, String languageModelSupplierClassName,
List<String> toolsClassNames, String chatMemoryProviderSupplierClassName,
String retrieverSupplierClassName,
String retrieverClassName,
String auditServiceClassSupplierName,
String moderationModelSupplierClassName) {
this.serviceClassName = serviceClassName;
this.languageModelSupplierClassName = languageModelSupplierClassName;
this.toolsClassNames = toolsClassNames;
this.chatMemoryProviderSupplierClassName = chatMemoryProviderSupplierClassName;
this.retrieverSupplierClassName = retrieverSupplierClassName;
this.retrieverClassName = retrieverClassName;
this.auditServiceClassSupplierName = auditServiceClassSupplierName;
this.moderationModelSupplierClassName = moderationModelSupplierClassName;
}
Expand All @@ -46,8 +46,8 @@ public String getChatMemoryProviderSupplierClassName() {
return chatMemoryProviderSupplierClassName;
}

public String getRetrieverSupplierClassName() {
return retrieverSupplierClassName;
public String getRetrieverClassName() {
return retrieverClassName;
}

public String getAuditServiceClassSupplierName() {
Expand Down
6 changes: 2 additions & 4 deletions docs/modules/ROOT/pages/retrievers.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,5 @@ Configure the maximum number of documents to retrieve (e.g., 20 in the example)
Make sure that the number of documents is not too high (or document too large).
More document you have, more data you are adding to the LLM context, and you may exceed the limit.

A retriever is used by default in your AI service, simply by having said retriever configured as a CDI bean,

Alternatively, implement a class implementing `Supplier<Retriever<TextSegment>>` if you prefer not to expose the retriever as a CDI bean.
Then, configure the `retrieverSupplier` attribute to point to your implementation.
An AI service does not use a retriever by default, one needs to be configured explicitly via the `retriever` property of `@RegisterAiService` and the configured
retriever is expected to be a CDI bean.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
Expand Down Expand Up @@ -116,6 +118,34 @@ public void test_simple_instruction_with_single_argument_and_no_annotations_from
assertSingleRequestMessage(getRequestAsMap(), "Tell me a joke about developers");
}

@Singleton
public static class DummyRetriever implements Retriever<TextSegment> {

@Override
public List<TextSegment> findRelevant(String text) {
return List.of(TextSegment.from("dummy"));
}
}

@RegisterAiService(retriever = DummyRetriever.class)
interface AssistantWithRetriever {

String chat(String message);
}

@Inject
AssistantWithRetriever assistantWithRetriever;

@Test
@ActivateRequestContext
public void test_simple_instruction_with_retriever() throws IOException {
String result = assistantWithRetriever.chat("Tell me a joke about developers");
assertThat(result).isNotBlank();

assertSingleRequestMessage(getRequestAsMap(),
"Tell me a joke about developers\n\nHere is some information that might be useful for answering:\n\ndummy");
}

enum Sentiment {
POSITIVE,
NEUTRAL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService
@RegisterAiService(retriever = RetrieverExample.class)
@Singleton // this is singleton because WebSockets currently never closes the scope
public interface Bot {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService
@RegisterAiService(retriever = RetrieverExample.class)
@Singleton // this is singleton because WebSockets currently never closes the scope
public interface MovieMuse {

Expand Down

0 comments on commit f2398bd

Please sign in to comment.