Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fail at build time when tools set and no memory provided #92

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,17 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
needChatModelBean = true;
}

DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER;
List<DotName> toolDotNames = Collections.emptyList();
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
.collect(Collectors.toList());
}

// the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean
DotName chatMemoryProviderSupplierClassDotName = toolDotNames.isEmpty()
? Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER
: Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER;
AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier");
if (chatMemoryProviderSupplierValue != null) {
chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name();
Expand All @@ -184,18 +194,6 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

List<DotName> toolDotNames = Collections.emptyList();
AnnotationValue toolsInstance = instance.value("tools");
if (toolsInstance != null) {
if (chatMemoryProviderSupplierClassDotName == null) {
throw new IllegalConfigurationException("Class '" + declarativeAiServiceClassInfo.name()
+ "' which is annotated with @RegisterAiService has configured tools support, but no ChatMemoryProvider configuration is present. Please set up chatMemoryProvider in order to use tools. A ChatMemory that can hold at least 3 messages is required for the tools to work properly. While the LLM can technically execute a tool without chat memory, if it only receives the result of the tool's execution without the initial message from the user, it won't interpret the result properly.");
}

toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name)
.collect(Collectors.toList());
}

DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER;
AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier");
if (retrieverSupplierValue != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
/**
* Tool classes to use. All tools are expected to be CDI beans.
* <p>
* NOTE: when this is used, {@code chatMemoryProviderSupplier} must NOT be set to {@link NoChatMemoryProviderSupplier}.
* NOTE: when this is used, either a {@link ChatMemoryProvider} bean must be present in the application, or a custom
* {@link Supplier<ChatMemoryProvider>} must be set.
*/
Class<?>[] tools() default {};

Expand All @@ -57,6 +58,10 @@
* <p>
* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then
* set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier}
* <p>
* NOTE: when {@link tools} is set, the default is changed to {@link BeanChatMemoryProviderSupplier} which means that a
* bean a {@link ChatMemoryProvider} bean must be present. The alternative in this case is to set a custom
* {@link Supplier<ChatMemoryProvider>}.
*/
Class<? extends Supplier<ChatMemoryProvider>> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.quarkiverse.langchain4j.openai.test;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;

import jakarta.enterprise.inject.spi.DeploymentException;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.agent.tool.Tool;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

public class AiServiceWithToolsAndNoMemoryTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses())
.assertException(t -> {
assertThat(t)
.isInstanceOf(DeploymentException.class)
.hasMessageContaining("ChatMemoryProvider");
});

@RegisterAiService(tools = CustomTool.class)
interface Assistant {

String chat(String input);
}

@Singleton
static class CustomTool {

@Tool
void doSomething() {

}
}

@Inject
Assistant assistant;

@Test
void test() {
fail("Should not be called");
}
}
Loading