Skip to content

Commit

Permalink
Merge pull request #145 from quarkiverse/#138
Browse files Browse the repository at this point in the history
Properly support Smallrye Fault Tolerance
  • Loading branch information
geoand authored Dec 14, 2023
2 parents f2f5876 + 9674a38 commit 3362655
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import jakarta.annotation.PreDestroy;
import jakarta.enterprise.context.Dependent;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
Expand All @@ -50,7 +53,6 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemoryRemovable;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceBeanDestroyer;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
Expand All @@ -59,6 +61,8 @@
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.InstanceHandle;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanBuildItem;
import io.quarkus.arc.deployment.GeneratedBeanGizmoAdaptor;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
Expand Down Expand Up @@ -311,16 +315,18 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
: null);

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(declarativeAiServiceClassInfo.name())
.configure(QuarkusAiServiceContext.class)
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName,
auditServiceClassSupplierName,
moderationModelSupplierClassName)))
.destroyer(DeclarativeAiServiceBeanDestroyer.class)
.setRuntimeInit()
.scope(bi.getCdiScope());
.addQualifier()
.annotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER).addValue("value", serviceClassName)
.done()
.scope(Dependent.class);
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.CHAT_MODEL));
needsChatModelBean = true;
Expand Down Expand Up @@ -392,6 +398,7 @@ public void handleAiServices(AiServicesRecorder recorder,
CombinedIndexBuildItem indexBuildItem,
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
BuildProducer<GeneratedClassBuildItem> generatedClassProducer,
BuildProducer<GeneratedBeanBuildItem> generatedBeanProducer,
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer,
BuildProducer<AiServicesMethodBuildItem> aiServicesMethodProducer,
BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer,
Expand Down Expand Up @@ -476,7 +483,8 @@ public void handleAiServices(AiServicesRecorder recorder,

Map<String, AiServiceClassCreateInfo> perClassMetadata = new HashMap<>();
if (!ifacesForCreate.isEmpty()) {
ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
ClassOutput generatedClassOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
ClassOutput generatedBeanOutput = new GeneratedBeanGizmoAdaptor(generatedBeanProducer);
for (ClassInfo iface : ifacesForCreate) {
Set<MethodInfo> allMethods = new HashSet<>(iface.methods());
JandexUtil.getAllSuperinterfaces(iface, index).forEach(ci -> allMethods.addAll(ci.methods()));
Expand All @@ -497,13 +505,22 @@ public void handleAiServices(AiServicesRecorder recorder,
boolean isRegisteredService = registeredAiServiceClassNames.contains(ifaceName);

ClassCreator.Builder classCreatorBuilder = ClassCreator.builder()
.classOutput(classOutput)
.classOutput(isRegisteredService ? generatedBeanOutput : generatedClassOutput)
.className(implClassName)
.interfaces(ifaceName, ChatMemoryRemovable.class.getName());
if (isRegisteredService) {
classCreatorBuilder.interfaces(AutoCloseable.class);
}
try (ClassCreator classCreator = classCreatorBuilder.build()) {
if (isRegisteredService) {
// we need to make this a bean, so we need to add the proper scope annotation
ScopeInfo scopeInfo = declarativeAiServiceItems.stream()
.filter(bi -> bi.getServiceClassInfo().equals(iface))
.findFirst().orElseThrow(() -> new IllegalStateException(
"Unable to determine the CDI scope of " + iface))
.getCdiScope();
classCreator.addAnnotation(scopeInfo.getDotName().toString());
}

FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
Expand All @@ -516,37 +533,67 @@ public void handleAiServices(AiServicesRecorder recorder,
String methodId = createMethodId(methodInfo);
perMethodMetadata.put(methodId,
gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan));
MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
QuarkusAiServiceContext.class);
constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis());
constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0));
constructor.returnValue(null);

MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
mc.load(ifaceName),
mc.load(methodId));
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i));
{
MethodCreator ctor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
QuarkusAiServiceContext.class);
ctor.setModifiers(Modifier.PUBLIC);
ctor.addAnnotation(Inject.class);
ctor.getParameterAnnotations(0)
.addAnnotation(Langchain4jDotNames.QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER.toString())
.add("value", ifaceName);
ctor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, ctor.getThis());
ctor.writeInstanceField(contextField, ctor.getThis(),
ctor.getMethodParam(0));
ctor.returnValue(null);
}

ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);

ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
mc.returnValue(resultHandle);
{
MethodCreator noArgsCtor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V");
noArgsCtor.setModifiers(Modifier.PUBLIC);
noArgsCtor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, noArgsCtor.getThis());
noArgsCtor.writeInstanceField(contextField, noArgsCtor.getThis(), noArgsCtor.loadNull());
noArgsCtor.returnValue(null);
}

aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
{ // actual method we need to implement
MethodCreator mc = classCreator.getMethodCreator(MethodDescriptor.of(methodInfo));

// copy annotations
for (AnnotationInstance annotationInstance : methodInfo.declaredAnnotations()) {
// TODO: we need to review this
if (annotationInstance.name().toString()
.startsWith("org.eclipse.microprofile.faulttolerance")) {
mc.addAnnotation(annotationInstance);
}
}

ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
ResultHandle methodCreateInfoHandle = mc.invokeStaticMethod(RECORDER_METHOD_CREATE_INFO,
mc.load(ifaceName),
mc.load(methodId));
ResultHandle paramsHandle = mc.newArray(Object.class, methodInfo.parametersCount());
for (int i = 0; i < methodInfo.parametersCount(); i++) {
mc.writeArrayValue(paramsHandle, i, mc.getMethodParam(i));
}

ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class,
Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);

ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
mc.returnValue(resultHandle);

aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo));
}
}

if (isRegisteredService) {
MethodCreator mc = classCreator.getMethodCreator(
MethodDescriptor.ofMethod(implClassName, "close", void.class));
mc.addAnnotation(PreDestroy.class);
ResultHandle contextHandle = mc.readInstanceField(contextField, mc.getThis());
mc.invokeVirtualMethod(QUARKUS_AI_SERVICES_CONTEXT_CLOSE, contextHandle);
mc.returnVoid();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.quarkiverse.langchain4j.CreatedAware;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContextQualifier;

public class Langchain4jDotNames {
public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class);
Expand Down Expand Up @@ -67,4 +68,7 @@ public class Langchain4jDotNames {
static final DotName NO_MODERATION_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.NoModerationModelSupplier.class);

static final DotName QUARKUS_AI_SERVICE_CONTEXT_QUALIFIER = DotName.createSimple(
QuarkusAiServiceContextQualifier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
.loadClass(info.getServiceClassName());

QuarkusAiServiceContext aiServiceContext = new QuarkusAiServiceContext(serviceClass);
// we don't really care about QuarkusAiServices here, all we care about is that it
// properly populates QuarkusAiServiceContext which is what we are trying to construct
var quarkusAiServices = INSTANCE.create(aiServiceContext);

if (info.getLanguageModelSupplierClassName() != null) {
Expand Down Expand Up @@ -164,7 +166,7 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

return (T) quarkusAiServices.build();
return (T) aiServiceContext;
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
} catch (InvocationTargetException | NoSuchMethodException | IllegalAccessException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ public class QuarkusAiServiceContext extends AiServiceContext {

public AuditService auditService;

// needed by Arc
public QuarkusAiServiceContext() {
super(null);
}

public QuarkusAiServiceContext(Class<?> aiServiceClass) {
super(aiServiceClass);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import static java.lang.annotation.ElementType.PARAMETER;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;

import jakarta.enterprise.util.AnnotationLiteral;
import jakarta.inject.Qualifier;

@Qualifier
@Inherited
@Target({ PARAMETER })
@Retention(RUNTIME)
public @interface QuarkusAiServiceContextQualifier {

/**
* The name of class
*/
String value();

class Literal extends AnnotationLiteral<QuarkusAiServiceContextQualifier> implements QuarkusAiServiceContextQualifier {

public static Literal of(String value) {
return new Literal(value);
}

private final String value;

public Literal(String value) {
this.value = value;
}

@Override
public String value() {
return value;
}
}
}
4 changes: 4 additions & 0 deletions integration-tests/openai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
<groupId>io.quarkus</groupId>
<artifactId>quarkus-micrometer</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.acme.example.openai.aiservices;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

import org.eclipse.microprofile.faulttolerance.Fallback;

import dev.langchain4j.service.SystemMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@Path("assistant-with-fallback")
public class AssistantResourceWithFallback {

private final Assistant assistant;

public AssistantResourceWithFallback(Assistant assistant) {
this.assistant = assistant;
}

@GET
public String get() {
return assistant.chat("test");
}

@RegisterAiService
interface Assistant {

@SystemMessage("""
Help me: {something}
""")
@Fallback(fallbackMethod = "fallback")
String chat(String message);

static String fallback(String message) {
return "This is a fallback message";
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.acme.example.openai.aiservices;

import static io.restassured.RestAssured.given;
import static org.hamcrest.CoreMatchers.equalTo;

import java.net.URL;

import org.junit.jupiter.api.Test;

import io.quarkus.test.common.http.TestHTTPEndpoint;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.test.junit.QuarkusTest;

@QuarkusTest
class AssistantResourceWithFallbackTest {

@TestHTTPEndpoint(AssistantResourceWithFallback.class)
@TestHTTPResource
URL url;

@Test
public void fallback() {
given()
.baseUri(url.toString())
.get()
.then()
.statusCode(200)
.body(equalTo("This is a fallback message"));
}
}
Loading

0 comments on commit 3362655

Please sign in to comment.