From 21a6dc6a62d6d0afd71311cd3d7aee73bf7c1b41 Mon Sep 17 00:00:00 2001
From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com>
Date: Fri, 30 Aug 2024 18:41:28 +0200
Subject: [PATCH] add new streamingChatLanguageModelSupplier property to
RegisterAiService to support custom StreamingChatLanguageModel suppliers
Closes #842
Signed-off-by: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com>
---
core/deployment/pom.xml | 10 +++
.../deployment/AiServicesProcessor.java | 67 ++++++++++----
.../DeclarativeAiServiceBuildItem.java | 19 ++--
.../deployment/LangChain4jDotNames.java | 3 +
...BlockingChatLanguageModelSupplierTest.java | 62 +++++++++++++
...BlockingChatLanguageModelSupplierTest.java | 88 +++++++++++++++++++
...treamingChatLanguageModelSupplierTest.java | 67 ++++++++++++++
.../langchain4j/RegisterAiService.java | 23 +++++
.../runtime/AiServicesRecorder.java | 28 ++++--
.../DeclarativeAiServiceCreateInfo.java | 4 +-
10 files changed, 338 insertions(+), 33 deletions(-)
create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/BlockingChatLanguageModelSupplierTest.java
create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingAndBlockingChatLanguageModelSupplierTest.java
create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingChatLanguageModelSupplierTest.java
diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml
index 999088a2d..f62c830e1 100644
--- a/core/deployment/pom.xml
+++ b/core/deployment/pom.xml
@@ -69,6 +69,16 @@
${wiremock.version}
test
+
+ io.rest-assured
+ rest-assured
+ test
+
+
+ io.quarkus
+ quarkus-resteasy-reactive
+ test
+
dev.langchain4j
langchain4j-embeddings-all-minilm-l6-v2-q
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
index e895b0636..e014e3ccc 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
@@ -27,6 +27,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
+import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -209,22 +210,23 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
ClassInfo declarativeAiServiceClassInfo = instance.target().asClass();
- DotName chatLanguageModelSupplierClassDotName = null;
- AnnotationValue chatLanguageModelSupplierValue = instance.value("chatLanguageModelSupplier");
- if (chatLanguageModelSupplierValue != null) {
- chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierValue.asClass().name();
- if (chatLanguageModelSupplierClassDotName.equals(LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER)) { // this is the case where the
- // default was set, so we just
- // ignore it
- chatLanguageModelSupplierClassDotName = null;
- } else {
- validateSupplierAndRegisterForReflection(chatLanguageModelSupplierClassDotName, index,
- reflectiveClassProducer);
- }
- }
+ DotName chatLanguageModelSupplierClassDotName = getSupplierDotName(instance.value("chatLanguageModelSupplier"),
+ LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER,
+ supplierDotName -> validateSupplierAndRegisterForReflection(
+ supplierDotName,
+ index,
+ reflectiveClassProducer));
+
+ DotName streamingChatLanguageModelSupplierClassDotName = getSupplierDotName(
+ instance.value("streamingChatLanguageModelSupplier"),
+ LangChain4jDotNames.BEAN_STREAMING_CHAT_MODEL_SUPPLIER,
+ supplierDotName -> validateSupplierAndRegisterForReflection(
+ supplierDotName,
+ index,
+ reflectiveClassProducer));
String chatModelName = NamedConfigUtil.DEFAULT_NAME;
- if (chatLanguageModelSupplierClassDotName == null) {
+ if (chatLanguageModelSupplierClassDotName == null && streamingChatLanguageModelSupplierClassDotName == null) {
AnnotationValue modelNameValue = instance.value("modelName");
if (modelNameValue != null) {
String modelNameValueStr = modelNameValue.asString();
@@ -337,6 +339,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
+ streamingChatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverClassDotName,
@@ -359,6 +362,23 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}
+ private DotName getSupplierDotName(
+ AnnotationValue instanceAnnotation,
+ DotName supplierDotName,
+ Consumer validator) {
+ DotName dotName = null;
+ if (instanceAnnotation != null) {
+ dotName = instanceAnnotation.asClass().name();
+ if (dotName.equals(supplierDotName)) {
+ // this is the case where the default was set, so we just ignore it
+ dotName = null;
+ } else {
+ validator.accept(dotName);
+ }
+ }
+ return dotName;
+ }
+
private void validateSupplierAndRegisterForReflection(DotName supplierDotName, IndexView index,
BuildProducer producer) {
ClassInfo classInfo = index.getClassByName(supplierDotName);
@@ -396,8 +416,12 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
ClassInfo declarativeAiServiceClassInfo = bi.getServiceClassInfo();
String serviceClassName = declarativeAiServiceClassInfo.name().toString();
- String chatLanguageModelSupplierClassName = (bi.getLanguageModelSupplierClassDotName() != null
- ? bi.getLanguageModelSupplierClassDotName().toString()
+ String chatLanguageModelSupplierClassName = (bi.getChatLanguageModelSupplierClassDotName() != null
+ ? bi.getChatLanguageModelSupplierClassDotName().toString()
+ : null);
+
+ String streamingChatLanguageModelSupplierClassName = (bi.getStreamingChatLanguageModelSupplierClassDotName() != null
+ ? bi.getStreamingChatLanguageModelSupplierClassDotName().toString()
: null);
List toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());
@@ -428,6 +452,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
// determine whether the method returns Multi
boolean injectStreamingChatModelBean = false;
+ // currently in one class either streaming or blocking model are supported, but not both
+ // if we want to support it, the injectStreamingChatModelBean needs to be recorded per injection point
for (MethodInfo method : declarativeAiServiceClassInfo.methods()) {
if (!LangChain4jDotNames.MULTI.equals(method.returnType().name())) {
continue;
@@ -460,7 +486,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
.configure(QuarkusAiServiceContext.class)
.forceApplicationClass()
.createWith(recorder.createDeclarativeAiService(
- new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
+ new DeclarativeAiServiceCreateInfo(
+ serviceClassName,
+ chatLanguageModelSupplierClassName,
+ streamingChatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName,
retrievalAugmentorSupplierClassName,
auditServiceClassSupplierName,
@@ -476,7 +505,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
.done()
.scope(Dependent.class);
- if ((chatLanguageModelSupplierClassName == null) && !selectedChatModelProvider.isEmpty()) {
+ boolean hasChatModelSupplier = chatLanguageModelSupplierClassName == null
+ && streamingChatLanguageModelSupplierClassName == null;
+ if (hasChatModelSupplier && !selectedChatModelProvider.isEmpty()) {
if (NamedConfigUtil.isDefault(chatModelName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL));
if (injectStreamingChatModelBean) {
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java
index b0f030ade..9cbae7b24 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java
@@ -13,7 +13,8 @@
public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final ClassInfo serviceClassInfo;
- private final DotName languageModelSupplierClassDotName;
+ private final DotName chatLanguageModelSupplierClassDotName;
+ private final DotName streamingChatLanguageModelSupplierClassDotName;
private final List toolDotNames;
private final DotName chatMemoryProviderSupplierClassDotName;
@@ -27,7 +28,10 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final String chatModelName;
private final String moderationModelName;
- public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
+ public DeclarativeAiServiceBuildItem(
+ ClassInfo serviceClassInfo,
+ DotName chatLanguageModelSupplierClassDotName,
+ DotName streamingChatLanguageModelSupplierClassDotName,
List toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverClassDotName,
@@ -40,7 +44,8 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag
String chatModelName,
String moderationModelName) {
this.serviceClassInfo = serviceClassInfo;
- this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
+ this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
+ this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverClassDotName = retrieverClassDotName;
@@ -58,8 +63,12 @@ public ClassInfo getServiceClassInfo() {
return serviceClassInfo;
}
- public DotName getLanguageModelSupplierClassDotName() {
- return languageModelSupplierClassDotName;
+ public DotName getChatLanguageModelSupplierClassDotName() {
+ return chatLanguageModelSupplierClassDotName;
+ }
+
+ public DotName getStreamingChatLanguageModelSupplierClassDotName() {
+ return streamingChatLanguageModelSupplierClassDotName;
}
public List getToolDotNames() {
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java
index d3243ccc8..b73a3ad03 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/LangChain4jDotNames.java
@@ -57,6 +57,9 @@ public class LangChain4jDotNames {
static final DotName BEAN_CHAT_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatLanguageModelSupplier.class);
+ static final DotName BEAN_STREAMING_CHAT_MODEL_SUPPLIER = DotName.createSimple(
+ RegisterAiService.BeanStreamingChatLanguageModelSupplier.class);
+
static final DotName CHAT_MEMORY_PROVIDER = DotName.createSimple(ChatMemoryProvider.class);
static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/BlockingChatLanguageModelSupplierTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/BlockingChatLanguageModelSupplierTest.java
new file mode 100644
index 000000000..4a1314ca7
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/BlockingChatLanguageModelSupplierTest.java
@@ -0,0 +1,62 @@
+package io.quarkiverse.langchain4j.test;
+
+import static io.restassured.RestAssured.get;
+import static org.hamcrest.Matchers.equalTo;
+
+import java.util.function.Supplier;
+
+import jakarta.ws.rs.GET;
+import jakarta.ws.rs.Path;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkus.test.QuarkusUnitTest;
+
+public class BlockingChatLanguageModelSupplierTest {
+ @Path("/test")
+ static class MyResource {
+
+ private final MyService service;
+
+ MyResource(MyService service) {
+ this.service = service;
+ }
+
+ @GET
+ public String blocking() {
+ return service.chat("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
+ }
+ }
+
+ public static class MyModelSupplier implements Supplier {
+ @Override
+ public ChatLanguageModel get() {
+ return (messages) -> new Response<>(new AiMessage("42"));
+ }
+ }
+
+ @RegisterAiService(chatLanguageModelSupplier = MyModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
+ interface MyService {
+ String chat(String msg);
+ }
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+ .addClasses(MyResource.class, MyService.class, MyModelSupplier.class));
+
+ @Test
+ public void testCall() {
+ get("test")
+ .then()
+ .statusCode(200)
+ .body(equalTo("42"));
+ }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingAndBlockingChatLanguageModelSupplierTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingAndBlockingChatLanguageModelSupplierTest.java
new file mode 100644
index 000000000..69679da50
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingAndBlockingChatLanguageModelSupplierTest.java
@@ -0,0 +1,88 @@
+package io.quarkiverse.langchain4j.test;
+
+import static io.restassured.RestAssured.get;
+import static org.hamcrest.Matchers.equalTo;
+
+import java.util.function.Supplier;
+
+import jakarta.ws.rs.GET;
+import jakarta.ws.rs.Path;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+public class StreamingAndBlockingChatLanguageModelSupplierTest {
+ @Path("/test")
+ static class MyResource {
+
+ private final MyService service;
+
+ MyResource(MyService service) {
+ this.service = service;
+ }
+
+ @GET
+ @Path("/blocking")
+ public String blocking() {
+ return service.blocking("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
+ }
+
+ @GET
+ @Path("/streaming")
+ public Multi streaming() {
+ return service.streaming("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
+ }
+ }
+
+ public static class MyModelSupplier implements Supplier {
+ @Override
+ public ChatLanguageModel get() {
+ return (messages) -> new Response<>(new AiMessage("42"));
+ }
+ }
+
+ public static class MyStreamingModelSupplier implements Supplier {
+ @Override
+ public StreamingChatLanguageModel get() {
+ return (messages, handler) -> {
+ handler.onNext("4");
+ handler.onNext("2");
+ handler.onComplete(new Response<>(new AiMessage("")));
+ };
+ }
+ }
+
+ @RegisterAiService(chatLanguageModelSupplier = MyModelSupplier.class, streamingChatLanguageModelSupplier = MyStreamingModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
+ interface MyService {
+ Multi streaming(String msg);
+
+ String blocking(String msg);
+ }
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+ .addClasses(MyResource.class, MyService.class, MyModelSupplier.class, MyStreamingModelSupplier.class));
+
+ @Test
+ public void testCalls() {
+ get("test/blocking")
+ .then()
+ .statusCode(200)
+ .body(equalTo("42"));
+ get("test/streaming")
+ .then()
+ .statusCode(200)
+ .body(equalTo("42"));
+ }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingChatLanguageModelSupplierTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingChatLanguageModelSupplierTest.java
new file mode 100644
index 000000000..a7ec4a457
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/StreamingChatLanguageModelSupplierTest.java
@@ -0,0 +1,67 @@
+package io.quarkiverse.langchain4j.test;
+
+import static io.restassured.RestAssured.get;
+import static org.hamcrest.Matchers.equalTo;
+
+import java.util.function.Supplier;
+
+import jakarta.ws.rs.GET;
+import jakarta.ws.rs.Path;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkus.test.QuarkusUnitTest;
+import io.smallrye.mutiny.Multi;
+
+public class StreamingChatLanguageModelSupplierTest {
+ @Path("/test")
+ static class MyResource {
+
+ private final MyService service;
+
+ MyResource(MyService service) {
+ this.service = service;
+ }
+
+ @GET
+ public Multi blocking() {
+ return service.chat("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
+ }
+ }
+
+ public static class MyModelSupplier implements Supplier {
+ @Override
+ public StreamingChatLanguageModel get() {
+ return (messages, handler) -> {
+ handler.onNext("4");
+ handler.onNext("2");
+ handler.onComplete(new Response<>(new AiMessage("")));
+ };
+ }
+ }
+
+ @RegisterAiService(streamingChatLanguageModelSupplier = MyModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
+ interface MyService {
+ Multi chat(String msg);
+ }
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
+ .addClasses(MyResource.class, MyService.class, MyModelSupplier.class));
+
+ @Test
+ public void testCall() {
+ get("test")
+ .then()
+ .statusCode(200)
+ .body(equalTo("42"));
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
index 53fa9677e..1436ded61 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java
@@ -13,6 +13,7 @@
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.retriever.Retriever;
@@ -41,6 +42,15 @@
@Target(ElementType.TYPE)
public @interface RegisterAiService {
+ /**
+ * Configures the way to obtain the {@link StreamingChatLanguageModel} to use.
+ * If not configured, the default CDI bean implementing the model is looked up.
+ * Such a bean provided automatically by extensions such as {@code quarkus-langchain4j-openai},
+ * {@code quarkus-langchain4j-azure-openai} or
+ * {@code quarkus-langchain4j-hugging-face}
+ */
+ Class extends Supplier> streamingChatLanguageModelSupplier() default BeanStreamingChatLanguageModelSupplier.class;
+
/**
* Configures the way to obtain the {@link ChatLanguageModel} to use.
* If not configured, the default CDI bean implementing the model is looked up.
@@ -142,6 +152,19 @@ public ChatLanguageModel get() {
}
}
+ /**
+ * Marker that is used to tell Quarkus to use the {@link StreamingChatLanguageModel} that has been configured as a CDI bean
+ * by * any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
+ * {@code quarkus-langchain4j-hugging-face}).
+ */
+ final class BeanStreamingChatLanguageModelSupplier implements Supplier {
+
+ @Override
+ public StreamingChatLanguageModel get() {
+ throw new UnsupportedOperationException("should never be called");
+ }
+ }
+
/**
* Marker that is used to tell Quarkus to use the retriever that the user has configured as a CDI bean.
* Be default, Quarkus configures an {@link ChatMemoryProvider} by using an {@link InMemoryChatMemoryStore}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
index be058d7df..a63d56733 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java
@@ -111,19 +111,22 @@ public T apply(SyntheticCreationalContext creationalContext) {
// properly populates QuarkusAiServiceContext which is what we are trying to construct
var quarkusAiServices = INSTANCE.create(aiServiceContext);
- if (info.languageModelSupplierClassName() != null) {
- Supplier extends ChatLanguageModel> supplier = (Supplier extends ChatLanguageModel>) Thread
- .currentThread().getContextClassLoader().loadClass(info.languageModelSupplierClassName())
- .getConstructor().newInstance();
-
- quarkusAiServices.chatLanguageModel(supplier.get());
-
+ if (info.languageModelSupplierClassName() != null
+ || info.streamingChatLanguageModelSupplierClassName() != null) {
+ if (info.languageModelSupplierClassName() != null) {
+ Supplier extends ChatLanguageModel> supplier = createSupplier(
+ info.languageModelSupplierClassName());
+ quarkusAiServices.chatLanguageModel(supplier.get());
+ }
+ if (info.streamingChatLanguageModelSupplierClassName() != null) {
+ Supplier extends StreamingChatLanguageModel> supplier = createSupplier(
+ info.streamingChatLanguageModelSupplierClassName());
+ quarkusAiServices.streamingChatLanguageModel(supplier.get());
+ }
} else {
-
if (NamedConfigUtil.isDefault(info.chatModelName())) {
quarkusAiServices
.chatLanguageModel(creationalContext.getInjectedReference(ChatLanguageModel.class));
-
if (info.needsStreamingChatModel()) {
quarkusAiServices
.streamingChatLanguageModel(
@@ -252,4 +255,11 @@ public T apply(SyntheticCreationalContext creationalContext) {
}
};
}
+
+ private static Supplier createSupplier(String className) throws InstantiationException, IllegalAccessException,
+ InvocationTargetException, NoSuchMethodException, ClassNotFoundException {
+ return (Supplier) Thread
+ .currentThread().getContextClassLoader().loadClass(className)
+ .getConstructor().newInstance();
+ }
}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java
index d1ea5c34e..2e8241551 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java
@@ -2,8 +2,10 @@
import java.util.List;
-public record DeclarativeAiServiceCreateInfo(String serviceClassName,
+public record DeclarativeAiServiceCreateInfo(
+ String serviceClassName,
String languageModelSupplierClassName,
+ String streamingChatLanguageModelSupplierClassName,
List toolsClassNames,
String chatMemoryProviderSupplierClassName,
String retrieverClassName,