diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index e8df24a72..52b850ca9 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -14,6 +14,7 @@ import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW; import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.IGNORE; import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.OPTIONAL_DENY; +import static io.quarkus.arc.processor.DotNames.NAMED; import java.io.IOException; import java.io.InputStream; @@ -380,6 +381,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, String imageModelName = chatModelName; // TODO: should we have a separate setting for this? + AnnotationInstance namedAnno = declarativeAiServiceClassInfo.annotation(NAMED); + Optional beanName = Optional.empty(); + if (namedAnno != null) { + beanName = Optional.ofNullable(namedAnno.value().asString()); + } + declarativeAiServiceProducer.produce( new DeclarativeAiServiceBuildItem( declarativeAiServiceClassInfo, @@ -398,7 +405,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, chatModelName, moderationModelName, imageModelName, - toolProviderClassName)); + toolProviderClassName, + beanName)); } toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos)); @@ -705,6 +713,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, allToolProviders.add(toolProvider); } + bi.getBeanName().ifPresent(beanName -> configurator.named(beanName)); + configurator .addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE, new Type[] { ClassType.create(OutputGuardrail.class) }, null)) diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index 7c8ce19a4..6fd1fc997 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.deployment; import java.util.List; +import java.util.Optional; import org.jboss.jandex.ClassInfo; import org.jboss.jandex.DotName; @@ -30,6 +31,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final String chatModelName; private final String moderationModelName; private final String imageModelName; + private final Optional beanName; public DeclarativeAiServiceBuildItem( ClassInfo serviceClassInfo, @@ -48,7 +50,8 @@ public DeclarativeAiServiceBuildItem( String chatModelName, String moderationModelName, String imageModelName, - DotName toolProviderClassDotName) { + DotName toolProviderClassDotName, + Optional beanName) { this.serviceClassInfo = serviceClassInfo; this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName; this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName; @@ -66,6 +69,7 @@ public DeclarativeAiServiceBuildItem( this.moderationModelName = moderationModelName; this.imageModelName = imageModelName; this.toolProviderClassDotName = toolProviderClassDotName; + this.beanName = beanName; } public ClassInfo getServiceClassInfo() { @@ -135,4 +139,8 @@ public String getImageModelName() { public DotName getToolProviderClassDotName() { return toolProviderClassDotName; } + + public Optional getBeanName() { + return beanName; + } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/NamedAiServicesAreResolvableByNameTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/NamedAiServicesAreResolvableByNameTest.java new file mode 100644 index 000000000..31797911c --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/NamedAiServicesAreResolvableByNameTest.java @@ -0,0 +1,56 @@ +package io.quarkiverse.langchain4j.test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import jakarta.enterprise.inject.spi.BeanManager; +import jakarta.inject.Inject; +import jakarta.inject.Named; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkus.test.QuarkusUnitTest; + +public class NamedAiServicesAreResolvableByNameTest { + + private static final String MY_NAMED_SERVICE_BEAN = "myNamedServiceBean"; + + @Inject + BeanManager beanManager; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyNamedService.class)); + + @Named(MY_NAMED_SERVICE_BEAN) + @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class) + interface MyNamedService { + @UserMessage("Dummy prompt for " + MY_NAMED_SERVICE_BEAN) + String chat(); + } + + @Singleton + public static class MyLanguageModel implements ChatLanguageModel { + @Override + public Response generate(List messages) { + return null; + } + } + + @Test + void namedAiServiceCouldBeResolvedByNameTest() { + assertEquals(1, beanManager.getBeans(MY_NAMED_SERVICE_BEAN).size()); + } +}