From 3766cfb69a429593dd3a990894803c34cee6ecc3 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 1 Oct 2024 18:11:12 +0300 Subject: [PATCH] Avoid duplicating info for AiService implementation constructors Fixes: #954 --- .../deployment/AiServicesProcessor.java | 43 ++++++++++--------- .../aiservices/DeclarativeAiServicesTest.java | 8 +++- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 1cfce01ff..c2b25faf3 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -916,6 +916,28 @@ public void handleAiServices( .setModifiers(Modifier.PRIVATE | Modifier.FINAL) .getFieldDescriptor(); + { + MethodCreator ctor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V", + QuarkusAiServiceContext.class); + ctor.setModifiers(Modifier.PUBLIC); + ctor.addAnnotation(Inject.class); + ctor.getParameterAnnotations(0) + .addAnnotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER.toString()) + .add("value", ifaceName); + ctor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, ctor.getThis()); + ctor.writeInstanceField(contextField, ctor.getThis(), + ctor.getMethodParam(0)); + ctor.returnValue(null); + } + + { + MethodCreator noArgsCtor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V"); + noArgsCtor.setModifiers(Modifier.PUBLIC); + noArgsCtor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, noArgsCtor.getThis()); + noArgsCtor.writeInstanceField(contextField, noArgsCtor.getThis(), noArgsCtor.loadNull()); + noArgsCtor.returnValue(null); + } + for (MethodInfo methodInfo : methodsToImplement) { // The implementation essentially gets the context and delegates to // MethodImplementationSupport#implement @@ -940,27 +962,6 @@ public void handleAiServices( .beanClassNames(methodCreateInfo.getToolClassNames().toArray(EMPTY_STRING_ARRAY))); } perMethodMetadata.put(methodId, methodCreateInfo); - { - MethodCreator ctor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V", - QuarkusAiServiceContext.class); - ctor.setModifiers(Modifier.PUBLIC); - ctor.addAnnotation(Inject.class); - ctor.getParameterAnnotations(0) - .addAnnotation(LangChain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER.toString()) - .add("value", ifaceName); - ctor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, ctor.getThis()); - ctor.writeInstanceField(contextField, ctor.getThis(), - ctor.getMethodParam(0)); - ctor.returnValue(null); - } - - { - MethodCreator noArgsCtor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V"); - noArgsCtor.setModifiers(Modifier.PUBLIC); - noArgsCtor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, noArgsCtor.getThis()); - noArgsCtor.writeInstanceField(contextField, noArgsCtor.getThis(), noArgsCtor.loadNull()); - noArgsCtor.returnValue(null); - } { // actual method we need to implement MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo)); diff --git a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java index 0a507c34f..053dfc04b 100644 --- a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java +++ b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/DeclarativeAiServicesTest.java @@ -47,8 +47,10 @@ import io.quarkiverse.langchain4j.ImageUrl; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest; +import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; import io.quarkiverse.langchain4j.testing.internal.WiremockAware; import io.quarkus.arc.Arc; +import io.quarkus.arc.ClientProxy; import io.quarkus.test.QuarkusUnitTest; public class DeclarativeAiServicesTest extends OpenAiBaseTest { @@ -483,10 +485,14 @@ interface ImageDescriber { ImageDescriber imageDescriber; @Test - public void test_image_describer() throws IOException { + public void test_image_describer() throws Exception { var imageUrl = "https://foo.bar"; doTestImageDescriber(() -> imageDescriber.describe("Java", imageUrl, "NOT_AN_IMAGE")); doTestImageDescriber(() -> imageDescriber.describe2("Java", Image.builder().url(imageUrl).build(), "NOT_AN_IMAGE")); + + // make sure the class is properly generated - see https://github.com/quarkiverse/quarkus-langchain4j/issues/954 + ImageDescriber unwrapped = ClientProxy.unwrap(imageDescriber); + unwrapped.getClass().getConstructor(QuarkusAiServiceContext.class).getAnnotation(Inject.class); } private void doTestImageDescriber(Supplier describerSupplier) throws IOException {