From 9674a38d0e0b5eaa4005ff0af9b93126f2bba246 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Wed, 13 Dec 2023 12:14:14 +0200 Subject: [PATCH] Properly support Smallrye Fault Tolerance Fixes: #138 --- .../deployment/AiServicesProcessor.java | 105 +++++++++++++----- .../deployment/Langchain4jDotNames.java | 4 + .../runtime/AiServicesRecorder.java | 4 +- .../aiservice/QuarkusAiServiceContext.java | 5 + .../QuarkusAiServiceContextQualifier.java | 41 +++++++ integration-tests/openai/pom.xml | 4 + .../AssistantResourceWithFallback.java | 39 +++++++ .../AssistantResourceWithFallbackTest.java | 30 +++++ .../aiservices/DeclarativeAiServicesTest.java | 26 ++++- 9 files changed, 222 insertions(+), 36 deletions(-) create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContextQualifier.java create mode 100644 integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithFallback.java create mode 100644 integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithFallbackTest.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 9b379695d..af0afcae5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -23,7 +23,10 @@ import java.util.function.Predicate; import java.util.stream.Collectors; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.Dependent; import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTarget; @@ -50,7 +53,6 @@ import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import 
io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport; import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable; -import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer; import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper; import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; @@ -59,6 +61,8 @@ import io.quarkus.arc.ArcContainer; import io.quarkus.arc.InstanceHandle; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; +import io.quarkus.arc.deployment.GeneratedBeanBuildItem; +import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.arc.processor.BuiltinScope; @@ -311,16 +315,18 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, : null); SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem - .configure(declarativeAiServiceClassInfo.name()) + .configure(QuarkusAiServiceContext.class) .createWith(recorder.createDeclarativeAiService( new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName, toolClassNames, chatMemoryProviderSupplierClassName, retrieverSupplierClassName, auditServiceClassSupplierName, moderationModelSupplierClassName))) - .destroyer(DeclarativeAiServiceBeanDestroyer.class) .setRuntimeInit() - .scope(bi.getCdiScope()); + .addQualifier() + .annotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName) + .done() + .scope(Dependent.class); if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed? 
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL)); needsChatModelBean = true; @@ -392,6 +398,7 @@ public void handleAiServices(AiServicesRecorder recorder, CombinedIndexBuildItem indexBuildItem, List declarativeAiServiceItems, BuildProducer generatedClassProducer, + BuildProducer generatedBeanProducer, BuildProducer reflectiveClassProducer, BuildProducer aiServicesMethodProducer, BuildProducer additionalBeanProducer, @@ -476,7 +483,8 @@ public void handleAiServices(AiServicesRecorder recorder, Map perClassMetadata = new HashMap<>(); if (!ifacesForCreate.isEmpty()) { - ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true); + ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true); + ClassOutput generatedBeanOutput = new GeneratedBeanGizmoAdaptor(generatedBeanProducer); for (ClassInfo iface : ifacesForCreate) { Set allMethods = new HashSet<>(iface.methods()); JandexUtil.getAllSuperinterfaces(iface, index).forEach(ci -> allMethods.addAll(ci.methods())); @@ -497,13 +505,22 @@ public void handleAiServices(AiServicesRecorder recorder, boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName); ClassCreator.Builder classCreatorBuilder = ClassCreator.builder() - .classOutput(classOutput) + .classOutput(isRegisteredService ? 
generatedBeanOutput : generatedClassOutput) .className(implClassName) .interfaces(ifaceName, ChatMemoryRemovable.class.getName()); if (isRegisteredService) { classCreatorBuilder.interfaces(AutoCloseable.class); } try (ClassCreator classCreator = classCreatorBuilder.build()) { + if (isRegisteredService) { + // we need to make this a bean, so we need to add the proper scope annotation + ScopeInfo scopeInfo = declarativeAiServiceItems.stream() + .filter(bi -> bi.getServiceClassInfo().equals(iface)) + .findFirst().orElseThrow(() -> new IllegalStateException( + "Unable to determine the CDI scope of " + iface)) + .getCdiScope(); + classCreator.addAnnotation(scopeInfo.getDotName().toString()); + } FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class) .setModifiers(Modifier.PRIVATE | Modifier.FINAL) @@ -516,37 +533,67 @@ public void handleAiServices(AiServicesRecorder recorder, String methodId = createMethodId(methodInfo); perMethodMetadata.put(methodId, gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan)); - MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V", - QuarkusAiServiceContext.class); - constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis()); - constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0)); - constructor.returnValue(null); - - MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo)); - ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis()); - ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO, - mc.load(ifaceName), - mc.load(methodId)); - ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount()); - for (int i = 0; i < methodInfo.parametersCount(); i++) { - mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i)); + { + MethodCreator ctor = 
classCreator.getMethodCreator(MethodDescriptor.INIT, "V", + QuarkusAiServiceContext.class); + ctor.setModifiers(Modifier.PUBLIC); + ctor.addAnnotation(Inject.class); + ctor.getParameterAnnotations(0) + .addAnnotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER.toString()) + .add("value", ifaceName); + ctor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, ctor.getThis()); + ctor.writeInstanceField(contextField, ctor.getThis(), + ctor.getMethodParam(0)); + ctor.returnValue(null); } - ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName()); - ResultHandle inputHandle = mc.newInstance( - MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class, - QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class), - contextHandle, methodCreateInfoHandle, paramsHandle); - - ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle); - mc.returnValue(resultHandle); + { + MethodCreator noArgsCtor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V"); + noArgsCtor.setModifiers(Modifier.PUBLIC); + noArgsCtor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, noArgsCtor.getThis()); + noArgsCtor.writeInstanceField(contextField, noArgsCtor.getThis(), noArgsCtor.loadNull()); + noArgsCtor.returnValue(null); + } - aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo)); + { // actual method we need to implement + MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo)); + + // copy annotations + for (AnnotationInstance annotationInstance : methodInfo.declaredAnnotations()) { + // TODO: we need to review this + if (annotationInstance.name().toString() + .startsWith("org.eclipse.microprofile.faulttolerance")) { + mc.addAnnotation(annotationInstance); + } + } + + ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis()); + ResultHandle methodCreateInfoHandle = 
mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO, + mc.load(ifaceName), + mc.load(methodId)); + ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount()); + for (int i = 0; i < methodInfo.parametersCount(); i++) { + mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i)); + } + + ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName()); + ResultHandle inputHandle = mc.newInstance( + MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class, + QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, + Object[].class), + contextHandle, methodCreateInfoHandle, paramsHandle); + + ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle); + mc.returnValue(resultHandle); + + aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo)); + } } if (isRegisteredService) { MethodCreator mc = classCreator.getMethodCreator( MethodDescriptor.ofMethod(implClassName, "close", void.class)); + mc.addAnnotation(PreDestroy.class); ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis()); mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle); mc.returnVoid(); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java index 1fdf792c5..49fd961b5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java @@ -21,6 +21,7 @@ import io.quarkiverse.langchain4j.CreatedAware; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditService; +import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier; public class Langchain4jDotNames { public static final 
DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class); @@ -67,4 +68,7 @@ public class Langchain4jDotNames { static final DotName NO_MODERATION_MODEL_SUPPLIER = DotName.createSimple( RegisterAiService.NoModerationModelSupplier.class); + static final DotName QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER = DotName.createSimple( + QuarkusAiServiceContextQualifier.class); + } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index d7bc6bd42..e4cf4b049 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -75,6 +75,8 @@ public T apply(SyntheticCreationalContext creationalContext) { .loadClass(info.getServiceClassName()); QuarkusAiServiceContext aiServiceContext = new QuarkusAiServiceContext(serviceClass); + // we don't really care about QuarkusAiServices here, all we care about is that it + // properly populates QuarkusAiServiceContext which is what we are trying to construct var quarkusAiServices = INSTANCE.create(aiServiceContext); if (info.getLanguageModelSupplierClassName() != null) { @@ -164,7 +166,7 @@ public T apply(SyntheticCreationalContext creationalContext) { } } - return (T) quarkusAiServices.build(); + return (T) aiServiceContext; } catch (ClassNotFoundException e) { throw new IllegalStateException(e); } catch (InvocationTargetException | NoSuchMethodException | IllegalAccessException diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java index 82c85ece5..ac97e1fdb 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java +++ 
b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java @@ -11,6 +11,11 @@ public class QuarkusAiServiceContext extends AiServiceContext { public AuditService auditService; + // needed by Arc + public QuarkusAiServiceContext() { + super(null); + } + public QuarkusAiServiceContext(Class aiServiceClass) { super(aiServiceClass); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContextQualifier.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContextQualifier.java new file mode 100644 index 000000000..7eb74b924 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContextQualifier.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import jakarta.enterprise.util.AnnotationLiteral; +import jakarta.inject.Qualifier; + +@Qualifier +@Inherited +@Target({ PARAMETER }) +@Retention(RUNTIME) +public @interface QuarkusAiServiceContextQualifier { + + /** + * The class name of the AI service interface that the qualified context belongs to + */ + String value(); + + class Literal extends AnnotationLiteral<QuarkusAiServiceContextQualifier> implements QuarkusAiServiceContextQualifier { + + public static Literal of(String value) { + return new Literal(value); + } + + private final String value; + + public Literal(String value) { + this.value = value; + } + + @Override + public String value() { + return value; + } + } +} diff --git a/integration-tests/openai/pom.xml b/integration-tests/openai/pom.xml index 19ac8e131..35dca9cbe 100644 --- a/integration-tests/openai/pom.xml +++ b/integration-tests/openai/pom.xml @@ -25,6 +25,10 @@ io.quarkus quarkus-micrometer + + io.quarkus + quarkus-smallrye-fault-tolerance + 
io.quarkus quarkus-junit5 diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithFallback.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithFallback.java new file mode 100644 index 000000000..e37dea10b --- /dev/null +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithFallback.java @@ -0,0 +1,39 @@ +package org.acme.example.openai.aiservices; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import org.eclipse.microprofile.faulttolerance.Fallback; + +import dev.langchain4j.service.SystemMessage; +import io.quarkiverse.langchain4j.RegisterAiService; + +@Path("assistant-with-fallback") +public class AssistantResourceWithFallback { + + private final Assistant assistant; + + public AssistantResourceWithFallback(Assistant assistant) { + this.assistant = assistant; + } + + @GET + public String get() { + return assistant.chat("test"); + } + + @RegisterAiService + interface Assistant { + + @SystemMessage(""" + Help me: {something} + """) + @Fallback(fallbackMethod = "fallback") + String chat(String message); + + static String fallback(String message) { + return "This is a fallback message"; + } + } + +} diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithFallbackTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithFallbackTest.java new file mode 100644 index 000000000..9c5098997 --- /dev/null +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithFallbackTest.java @@ -0,0 +1,30 @@ +package org.acme.example.openai.aiservices; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.CoreMatchers.equalTo; + +import java.net.URL; + +import org.junit.jupiter.api.Test; + +import io.quarkus.test.common.http.TestHTTPEndpoint; +import 
io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +class AssistantResourceWithFallbackTest { + + @TestHTTPEndpoint(AssistantResourceWithFallback.class) + @TestHTTPResource + URL url; + + @Test + public void fallback() { + given() + .baseUri(url.toString()) + .get() + .then() + .statusCode(200) + .body(equalTo("This is a fallback message")); + } +} diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java index 8a25b94c7..5f9688aba 100644 --- a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java @@ -84,24 +84,38 @@ void setup() { wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub()); } - @RegisterAiService - interface Assistant { + interface AssistantBase { String chat(String message); } + @RegisterAiService + interface Assistant extends AssistantBase { + + String chat2(String message); + } + @Inject Assistant assistant; @Test @ActivateRequestContext - public void test_simple_instruction_with_single_argument_and_no_annotations() throws IOException { + public void test_simple_instruction_with_single_argument_and_no_annotations_from_super() throws IOException { String result = assistant.chat("Tell me a joke about developers"); assertThat(result).isNotBlank(); assertSingleRequestMessage(getRequestAsMap(), "Tell me a joke about developers"); } + @Test + @ActivateRequestContext + public void test_simple_instruction_with_single_argument_and_no_annotations_from_iface() throws IOException { + String result = assistant.chat2("Tell me a joke about developers"); + assertThat(result).isNotBlank(); + + assertSingleRequestMessage(getRequestAsMap(), "Tell me a joke about 
developers"); + } + enum Sentiment { POSITIVE, NEUTRAL, @@ -195,8 +209,8 @@ public void deleteMessages(Object memoryId) { } } - @RegisterAiService(tools = Calculator.class, chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class) - interface AssistantWithCalculator extends Assistant { + @RegisterAiService(tools = Calculator.class) + interface AssistantWithCalculator extends AssistantBase { } @@ -290,7 +304,7 @@ void should_execute_tool_then_answer() throws IOException { new MessageAssertUtils.MessageContent("function", "6.97070153193991E8"))); } - @RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class) + @RegisterAiService interface ChatWithSeparateMemoryForEachUser { String chat(@MemoryId int memoryId, @UserMessage String userMessage);