diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index d4939f696..6e7e53738 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -51,6 +51,7 @@ import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.MethodImplementationSupport; +import io.quarkiverse.langchain4j.runtime.aiservice.MetricsProducingMethodImplementationSupport; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.UnremovableBeanBuildItem; import io.quarkus.builder.item.MultiBuildItem; @@ -62,6 +63,7 @@ import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.GeneratedClassBuildItem; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; +import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; import io.quarkus.gizmo.FieldDescriptor; @@ -69,6 +71,7 @@ import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; import io.quarkus.gizmo.ResultHandle; +import io.quarkus.runtime.metrics.MetricsFactory; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class AiServicesProcessor { @@ -76,6 +79,7 @@ public class AiServicesProcessor { private static final Logger log = Logger.getLogger(AiServicesProcessor.class); private static final DotName V = DotName.createSimple(V.class); + public static final DotName MICROMETER_TIMED = DotName.createSimple("io.micrometer.core.annotation.Timed"); private static final String DEFAULT_DELIMITER = "\n"; private static final Predicate IS_METHOD_PARAMETER_ANNOTATION = ai -> ai.target() .kind() == AnnotationTarget.Kind.METHOD_PARAMETER; @@ -88,6 +92,9 @@ public class AiServicesProcessor { "getAiServiceMethodCreateInfo", AiServiceMethodCreateInfo.class, String.class, String.class); private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(MethodImplementationSupport.class, "implement", Object.class, AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class); + private static final MethodDescriptor SUPPORT_WITH_METRICS_IMPLEMENT = MethodDescriptor.ofMethod( + MetricsProducingMethodImplementationSupport.class, + "implement", Object.class, AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class); @BuildStep public void nativeSupport(CombinedIndexBuildItem indexBuildItem, @@ -321,7 +328,8 @@ public void handleAiServices(AiServicesRecorder recorder, List declarativeAiServiceItems, BuildProducer generatedClassProducer, BuildProducer reflectiveClassProducer, - BuildProducer aiServicesMethodProducer) { + BuildProducer aiServicesMethodProducer, + Optional metricsCapability) { IndexView index = indexBuildItem.getIndex(); @@ -386,6 +394,9 @@ public void handleAiServices(AiServicesRecorder recorder, ifacesForCreate.add(classInfo); } + var addMicrometerMetrics = metricsCapability.isPresent() + && metricsCapability.get().metricsSupported(MetricsFactory.MICROMETER); + Map perClassMetadata = new HashMap<>(); if (!ifacesForCreate.isEmpty()) { ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true); @@ -420,7 +431,7 @@ public void handleAiServices(AiServicesRecorder recorder, // MethodImplementationSupport#implement String methodId = createMethodId(methodInfo); - perMethodMetadata.put(methodId, gatherMethodMetadata(methodInfo)); + perMethodMetadata.put(methodId, gatherMethodMetadata(methodInfo, addMicrometerMetrics)); MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V", AiServiceContext.class); constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis()); @@ -437,7 +448,8 @@ public void handleAiServices(AiServicesRecorder recorder, mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i)); } - ResultHandle resultHandle = mc.invokeStaticMethod(SUPPORT_IMPLEMENT, contextHandle, + ResultHandle resultHandle = mc.invokeStaticMethod( + addMicrometerMetrics ? SUPPORT_WITH_METRICS_IMPLEMENT : SUPPORT_IMPLEMENT, contextHandle, methodCreateInfoHandle, paramsHandle); mc.returnValue(resultHandle); @@ -486,7 +498,7 @@ private static void addCreatedAware(IndexView index, Set detectedForCrea } } - private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method) { + private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolean addMicrometerMetrics) { if (method.returnType().kind() == Type.Kind.VOID) { throw illegalConfiguration("Return type of method '%s' cannot be void", method); } @@ -501,9 +513,10 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method) { AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams, returnType); Optional memoryIdParamPosition = gatherMemoryIdParamName(method); + Optional metricsInfo = gatherMetricsInfo(method, addMicrometerMetrics); return new AiServiceMethodCreateInfo(systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration, - returnType); + returnType, metricsInfo); } private List gatherTemplateParamInfo(List params) { @@ -620,6 +633,66 @@ private AiServiceMethodCreateInfo.UserMessageInfo gatherUserMessageInfo(MethodIn } } + private Optional gatherMetricsInfo(MethodInfo method, + boolean addMicrometerMetrics) { + if (!addMicrometerMetrics) { + return Optional.empty(); + } + + String name = defaultAiServiceMetricName(method); + + AnnotationInstance timedInstance = method.annotation(MICROMETER_TIMED); + if (timedInstance == null) { + timedInstance = method.declaringClass().declaredAnnotation(MICROMETER_TIMED); + } + + if (timedInstance == null) { + // we default to having all AiServices being timed + return Optional.of(new AiServiceMethodCreateInfo.MetricsInfo.Builder(name).build()); + } + + AnnotationValue nameValue = timedInstance.value(); + if (nameValue != null) { + String nameStr = nameValue.asString(); + if (nameStr != null && !nameStr.isEmpty()) { + name = nameStr; + } + } + + var builder = new AiServiceMethodCreateInfo.MetricsInfo.Builder(name); + + AnnotationValue extraTagsValue = timedInstance.value("extraTags"); + if (extraTagsValue != null) { + builder.setExtraTags(extraTagsValue.asStringArray()); + } + + AnnotationValue longTaskValue = timedInstance.value("longTask"); + if (longTaskValue != null) { + builder.setLongTask(longTaskValue.asBoolean()); + } + + AnnotationValue percentilesValue = timedInstance.value("percentiles"); + if (percentilesValue != null) { + builder.setPercentiles(percentilesValue.asDoubleArray()); + } + + AnnotationValue histogramValue = timedInstance.value("histogram"); + if (histogramValue != null) { + builder.setHistogram(histogramValue.asBoolean()); + } + + AnnotationValue descriptionValue = timedInstance.value("description"); + if (descriptionValue != null) { + builder.setDescription(descriptionValue.asString()); + } + + return Optional.of(builder.build()); + } + + private String defaultAiServiceMetricName(MethodInfo method) { + return "langchain4j.aiservices." + method.declaringClass().name().withoutPackagePrefix() + "." + method.name(); + } + private static class TemplateParameterInfo { private final int position; private final String name; diff --git a/core/runtime/pom.xml b/core/runtime/pom.xml index f7210695f..a0f99cb51 100644 --- a/core/runtime/pom.xml +++ b/core/runtime/pom.xml @@ -21,6 +21,11 @@ io.quarkus quarkus-qute + + io.micrometer + micrometer-core + true + dev.langchain4j diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index 6266ba6a3..78e176759 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -24,6 +24,9 @@ * {@link ChatMemoryProvider} and {@link Retriever} beans (which by default are configured if such beans exist). *

* NOTE: The resulting CDI bean is {@link ApplicationScoped}. + *

+ * NOTE: When the application also contains the {@code quarkus-micrometer} extension, metrics are automatically generated + * for the method invocations. */ @Retention(RUNTIME) @Target(ElementType.TYPE) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index e7a347f1b..5125d5372 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -15,15 +15,18 @@ public class AiServiceMethodCreateInfo { private final boolean requiresModeration; private final Class returnType; + private final Optional metricsInfo; + @RecordableConstructor public AiServiceMethodCreateInfo(Optional systemMessageInfo, UserMessageInfo userMessageInfo, Optional memoryIdParamPosition, - boolean requiresModeration, Class returnType) { + boolean requiresModeration, Class returnType, Optional metricsInfo) { this.systemMessageInfo = systemMessageInfo; this.userMessageInfo = userMessageInfo; this.memoryIdParamPosition = memoryIdParamPosition; this.requiresModeration = requiresModeration; this.returnType = returnType; + this.metricsInfo = metricsInfo; } public Optional getSystemMessageInfo() { @@ -46,6 +49,10 @@ public Class getReturnType() { return returnType; } + public Optional getMetricsInfo() { + return metricsInfo; + } + public static class UserMessageInfo { private final Optional template; private final Optional paramPosition; @@ -107,4 +114,95 @@ public Map getNameToParamPosition() { return nameToParamPosition; } } + + public static class MetricsInfo { + private final String name; + private final boolean longTask; + private final String[] extraTags; + private final double[] percentiles; + private final boolean histogram; + private final String description; + + public MetricsInfo(String name) { + this(name, false, null, null, false, null); + } + + @RecordableConstructor + public MetricsInfo(String name, boolean longTask, String[] extraTags, double[] percentiles, boolean histogram, + String description) { + this.name = name; + this.longTask = longTask; + this.extraTags = extraTags; + this.percentiles = percentiles; + this.histogram = histogram; + this.description = description; + } + + public String getName() { + return name; + } + + public boolean isLongTask() { + return longTask; + } + + public String[] getExtraTags() { + return extraTags; + } + + public double[] getPercentiles() { + return percentiles; + } + + public boolean isHistogram() { + return histogram; + } + + public String getDescription() { + return description; + } + + public static class Builder { + private final String name; + private boolean longTask = false; + private String[] extraTags = {}; + private double[] percentiles = {}; + private boolean histogram = false; + private String description = ""; + + public Builder(String name) { + this.name = name; + } + + public Builder setLongTask(boolean longTask) { + this.longTask = longTask; + return this; + } + + public Builder setExtraTags(String[] extraTags) { + this.extraTags = extraTags; + return this; + } + + public Builder setPercentiles(double[] percentiles) { + this.percentiles = percentiles; + return this; + } + + public Builder setHistogram(boolean histogram) { + this.histogram = histogram; + return this; + } + + public Builder setDescription(String description) { + this.description = description; + return this; + } + + public AiServiceMethodCreateInfo.MetricsInfo build() { + return new AiServiceMethodCreateInfo.MetricsInfo(name, longTask, extraTags, percentiles, histogram, + description); + } + } + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/MetricsProducingMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/MetricsProducingMethodImplementationSupport.java new file mode 100644 index 000000000..9ae6fb82d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/MetricsProducingMethodImplementationSupport.java @@ -0,0 +1,52 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import java.util.Optional; +import java.util.function.Supplier; + +import dev.langchain4j.service.AiServiceContext; +import io.micrometer.core.instrument.LongTaskTimer; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; + +/** + * When micrometer metrics are enabled, this is used to record how long calls take + */ +@SuppressWarnings("unused") // the methods are used in generated code +public class MetricsProducingMethodImplementationSupport { + + public static Object implement(AiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) { + Optional metricsInfoOpt = createInfo.getMetricsInfo(); + if (metricsInfoOpt.isPresent()) { + AiServiceMethodCreateInfo.MetricsInfo metricsInfo = metricsInfoOpt.get(); + if (metricsInfo.isLongTask()) { + LongTaskTimer timer = LongTaskTimer.builder(metricsInfo.getName()) + .description(metricsInfo.getDescription()) + .publishPercentiles(metricsInfo.getPercentiles()) + .publishPercentileHistogram(metricsInfo.isHistogram()) + .tags(metricsInfo.getExtraTags()) + .register(Metrics.globalRegistry); + return timer.record(new Supplier() { + @Override + public Object get() { + return MethodImplementationSupport.implement(context, createInfo, methodArgs); + } + }); + } else { + Timer timer = Timer.builder(metricsInfo.getName()) + .description(metricsInfo.getDescription()) + .publishPercentiles(metricsInfo.getPercentiles()) + .publishPercentileHistogram(metricsInfo.isHistogram()) + .tags(metricsInfo.getExtraTags()) + .register(Metrics.globalRegistry); + return timer.record(new Supplier() { + @Override + public Object get() { + return MethodImplementationSupport.implement(context, createInfo, methodArgs); + } + }); + } + } else { + return MethodImplementationSupport.implement(context, createInfo, methodArgs); + } + } +} diff --git a/integration-tests/openai/pom.xml b/integration-tests/openai/pom.xml index c10b4920c..19ac8e131 100644 --- a/integration-tests/openai/pom.xml +++ b/integration-tests/openai/pom.xml @@ -21,6 +21,10 @@ quarkus-langchain4j-openai ${project.version} + + io.quarkus + quarkus-micrometer + io.quarkus quarkus-junit5 diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithMetrics.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithMetrics.java new file mode 100644 index 000000000..00e1a0452 --- /dev/null +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantResourceWithMetrics.java @@ -0,0 +1,53 @@ +package org.acme.example.openai.aiservices; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; + +import io.micrometer.core.annotation.Timed; +import io.quarkiverse.langchain4j.RegisterAiService; + +@Path("assistant-with-metrics") +public class AssistantResourceWithMetrics { + + private final Assistant1 assistant1; + private final Assistant2 assistant2; + + public AssistantResourceWithMetrics(Assistant1 assistant1, Assistant2 assistant2) { + this.assistant1 = assistant1; + this.assistant2 = assistant2; + } + + @GET + @Path("a1") + public String assistant1() { + return assistant1.chat("test"); + } + + @GET + @Path("a2") + public String assistant2() { + return assistant2.chat("test"); + } + + @GET + @Path("a2c2") + public String assistant2Chat2() { + return assistant2.chat2("test"); + } + + @RegisterAiService + interface Assistant1 { + + String chat(String message); + } + + @RegisterAiService + @Timed(extraTags = { "key", "value" }) + interface Assistant2 { + + String chat(String message); + + @Timed(value = "a2c2", description = "Assistant2#chat2") + String chat2(String message); + } +} diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithMetricsTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithMetricsTest.java new file mode 100644 index 000000000..0834089a7 --- /dev/null +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithMetricsTest.java @@ -0,0 +1,84 @@ +package org.acme.example.openai.aiservices; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.containsString; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.URL; +import java.util.Collection; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.quarkus.test.common.http.TestHTTPEndpoint; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +class AssistantResourceWithMetricsTest { + + @TestHTTPEndpoint(AssistantResourceWithMetrics.class) + @TestHTTPResource + URL url; + + @Inject + MeterRegistry registry; + + @BeforeAll + static void addSimpleRegistry() { + Metrics.globalRegistry.add(new SimpleMeterRegistry()); + } + + @Test + public void noTimedAnnotations() throws InterruptedException { + given() + .baseUri(url.toString()) + .get("a1") + .then() + .statusCode(200) + .body(containsString("MockGPT")); + + waitForMeters(registry.find("langchain4j.aiservices.AssistantResourceWithMetrics$Assistant1.chat").timers(), 1); + } + + @Test + public void timedAnnotationOnClass() throws InterruptedException { + given() + .baseUri(url.toString()) + .get("a2") + .then() + .statusCode(200) + .body(containsString("MockGPT")); + + waitForMeters(registry.find("langchain4j.aiservices.AssistantResourceWithMetrics$Assistant2.chat").tag("key", "value") + .timers(), 1); + } + + @Test + public void timedAnnotationOnMethod() throws InterruptedException { + given() + .baseUri(url.toString()) + .get("a2c2") + .then() + .statusCode(200) + .body(containsString("MockGPT")); + + waitForMeters(registry.find("a2c2").timers(), 1); + } + + public void waitForMeters(Collection collection, int count) throws InterruptedException { + int i = 0; + do { + Thread.sleep(10); + } while (collection.size() < count && i++ < 5); + + if (i > 5) { + fail("Unable to find the requested metrics"); + } + } +}