From 975092d035d7d946f4461b8c88f420c979c2e492 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 5 Dec 2023 09:46:44 +0200 Subject: [PATCH] Fail at build time when tools set and no memory provided Closes: #86 --- .../deployment/AiServicesProcessor.java | 24 ++++----- .../langchain4j/RegisterAiService.java | 7 ++- .../AiServiceWithToolsAndNoMemoryTest.java | 52 +++++++++++++++++++ 3 files changed, 69 insertions(+), 14 deletions(-) create mode 100644 openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index fa4d6be16..e7f10598a 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -170,7 +170,17 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, needChatModelBean = true; } - DotName chatMemoryProviderSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER; + List toolDotNames = Collections.emptyList(); + AnnotationValue toolsInstance = instance.value("tools"); + if (toolsInstance != null) { + toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name) + .collect(Collectors.toList()); + } + + // the default value depends on whether tools exists or not - if they do, then we require a ChatMemoryProvider bean + DotName chatMemoryProviderSupplierClassDotName = toolDotNames.isEmpty() + ? Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER + : Langchain4jDotNames.BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER; AnnotationValue chatMemoryProviderSupplierValue = instance.value("chatMemoryProviderSupplier"); if (chatMemoryProviderSupplierValue != null) { chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierValue.asClass().name(); @@ -184,18 +194,6 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } } - List toolDotNames = Collections.emptyList(); - AnnotationValue toolsInstance = instance.value("tools"); - if (toolsInstance != null) { - if (chatMemoryProviderSupplierClassDotName == null) { - throw new IllegalConfigurationException("Class '" + declarativeAiServiceClassInfo.name() - + "' which is annotated with @RegisterAiService has configured tools support, but no ChatMemoryProvider configuration is present. Please set up chatMemoryProvider in order to use tools. A ChatMemory that can hold at least 3 messages is required for the tools to work properly. While the LLM can technically execute a tool without chat memory, if it only receives the result of the tool's execution without the initial message from the user, it won't interpret the result properly."); - } - - toolDotNames = Arrays.stream(toolsInstance.asClassArray()).map(Type::name) - .collect(Collectors.toList()); - } - DotName retrieverSupplierClassDotName = Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER; AnnotationValue retrieverSupplierValue = instance.value("retrieverSupplier"); if (retrieverSupplierValue != null) { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 78e176759..ca85a6082 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -44,7 +44,8 @@ /** * Tool classes to use. All tools are expected to be CDI beans. *

- * NOTE: when this is used, {@code chatMemoryProviderSupplier} must NOT be set to {@link NoChatMemoryProviderSupplier}. + * NOTE: when this is used, either a {@link ChatMemoryProvider} bean must be present in the application, or a custom + * {@link Supplier} must be set. */ Class[] tools() default {}; @@ -57,6 +58,10 @@ *

* If the memory provider to use is exposed as a CDI bean exposing the type {@link ChatMemoryProvider}, then * set the value to {@link RegisterAiService.BeanChatMemoryProviderSupplier} + *

+ * NOTE: when {@link tools} is set, the default is changed to {@link BeanChatMemoryProviderSupplier} which means that a + * bean a {@link ChatMemoryProvider} bean must be present. The alternative in this case is to set a custom + * {@link Supplier}. */ Class> chatMemoryProviderSupplier() default BeanIfExistsChatMemoryProviderSupplier.class; diff --git a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java new file mode 100644 index 000000000..81c2759d0 --- /dev/null +++ b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/AiServiceWithToolsAndNoMemoryTest.java @@ -0,0 +1,52 @@ +package io.quarkiverse.langchain4j.openai.test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import jakarta.enterprise.inject.spi.DeploymentException; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.agent.tool.Tool; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +public class AiServiceWithToolsAndNoMemoryTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses()) + .assertException(t -> { + assertThat(t) + .isInstanceOf(DeploymentException.class) + .hasMessageContaining("ChatMemoryProvider"); + }); + + @RegisterAiService(tools = CustomTool.class) + interface Assistant { + + String chat(String input); + } + + @Singleton + static class CustomTool { + + @Tool + void doSomething() { + + } + } + + @Inject + Assistant assistant; + + @Test + void test() { + fail("Should not be called"); + } +}