Skip to content

Commit

Permalink
Fail at build time when tools set and no memory provided
Browse files Browse the repository at this point in the history
Closes: #86
  • Loading branch information
geoand committed Dec 5, 2023
1 parent bdbf7f0 commit 975092d
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,17 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
needChatModelBean = true;
}

DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER;
List<DotName> toolDotNames = Collections.emptyList();
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
.collect(Collectors.toList());
}

// the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean
DotName chatMemoryProviderSupplierClassDotName = toolDotNames.isEmpty()
? Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER
: Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
if (chatMemoryProviderSupplierValue != null) {
chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
Expand All @@ -184,18 +194,6 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

List<DotName> toolDotNames = Collections.emptyList();
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
if (chatMemoryProviderSupplierClassDotName == null) {
throw new IllegalConfigurationException("Class '" + declarativeAiServiceClassInfo.name()
+ "' which is annotated with @RegisterAiService has configured tools support, but no ChatMemoryProvider configuration is present. Please set up chatMemoryProvider in order to use tools. A ChatMemory that can hold at least 3 messages is required for the tools to work properly. While the LLM can technically execute a tool without chat memory, if it only receives the result of the tool's execution without the initial message from the user, it won't interpret the result properly.");
}

toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
.collect(Collectors.toList());
}

DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER;
AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier");
if (retrieverSupplierValue != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
/**
* Tool classes to use. All tools are expected to be CDI beans.
* <p>
* NOTE: when this is used, {@code chatMemoryProviderSupplier} must NOT be set to {@link NoChatMemoryProviderSupplier}.
* NOTE: when this is used, either a {@link ChatMemoryProvider} bean must be present in the application, or a custom
* {@link Supplier<ChatMemoryProvider>} must be set.
*/
Class<?>[] tools() default {};

Expand All @@ -57,6 +58,10 @@
* <p>
* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then
* set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier}
* <p>
* NOTE: when {@link tools} is set, the default is changed to {@link BeanChatMemoryProviderSupplier} which means that a
* bean a {@link ChatMemoryProvider} bean must be present. The alternative in this case is to set a custom
* {@link Supplier<ChatMemoryProvider>}.
*/
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.quarkiverse.langchain4j.openai.test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;

import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.agent.tool.Tool;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

public class AiServiceWithToolsAndNoMemoryTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses())
.assertException(t -> {
assertThat(t)
.isInstanceOf(DeploymentException.class)
.hasMessageContaining("ChatMemoryProvider");
});

@RegisterAiService(tools = CustomTool.class)
interface Assistant {

String chat(String input);
}

@Singleton
static class CustomTool {

@Tool
void doSomething() {

}
}

@Inject
Assistant assistant;

@Test
void test() {
fail("Should not be called");
}
}

0 comments on commit 975092d

Please sign in to comment.