Skip to content

Commit

Permalink
Merge pull request #844 from dastrobu/streamingChatLanguageModelSupplier
Browse files Browse the repository at this point in the history
Add `streamingChatLanguageModelSupplier` property to `@RegisterAiService`
  • Loading branch information
geoand authored Sep 2, 2024
2 parents 7d1119d + 21a6dc6 commit 664087b
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 33 deletions.
10 changes: 10 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@
<version>${wiremock.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.rest-assured</groupId>
<artifactId>rest-assured</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -208,22 +209,23 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
ClassInfo declarativeAiServiceClassInfo = instance.target().asClass();

DotName chatLanguageModelSupplierClassDotName = null;
AnnotationValue chatLanguageModelSupplierValue = instance.value("chatLanguageModelSupplier");
if (chatLanguageModelSupplierValue != null) {
chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierValue.asClass().name();
if (chatLanguageModelSupplierClassDotName.equals(LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER)) { // this is the case where the
// default was set, so we just
// ignore it
chatLanguageModelSupplierClassDotName = null;
} else {
validateSupplierAndRegisterForReflection(chatLanguageModelSupplierClassDotName, index,
reflectiveClassProducer);
}
}
DotName chatLanguageModelSupplierClassDotName = getSupplierDotName(instance.value("chatLanguageModelSupplier"),
LangChain4jDotNames.BEAN_CHAT_MODEL_SUPPLIER,
supplierDotName -> validateSupplierAndRegisterForReflection(
supplierDotName,
index,
reflectiveClassProducer));

DotName streamingChatLanguageModelSupplierClassDotName = getSupplierDotName(
instance.value("streamingChatLanguageModelSupplier"),
LangChain4jDotNames.BEAN_STREAMING_CHAT_MODEL_SUPPLIER,
supplierDotName -> validateSupplierAndRegisterForReflection(
supplierDotName,
index,
reflectiveClassProducer));

String chatModelName = NamedConfigUtil.DEFAULT_NAME;
if (chatLanguageModelSupplierClassDotName == null) {
if (chatLanguageModelSupplierClassDotName == null && streamingChatLanguageModelSupplierClassDotName == null) {
AnnotationValue modelNameValue = instance.value("modelName");
if (modelNameValue != null) {
String modelNameValueStr = modelNameValue.asString();
Expand Down Expand Up @@ -336,6 +338,7 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
streamingChatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverClassDotName,
Expand All @@ -358,6 +361,23 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

private DotName getSupplierDotName(
AnnotationValue instanceAnnotation,
DotName supplierDotName,
Consumer<DotName> validator) {
DotName dotName = null;
if (instanceAnnotation != null) {
dotName = instanceAnnotation.asClass().name();
if (dotName.equals(supplierDotName)) {
// this is the case where the default was set, so we just ignore it
dotName = null;
} else {
validator.accept(dotName);
}
}
return dotName;
}

private void validateSupplierAndRegisterForReflection(DotName supplierDotName, IndexView index,
BuildProducer<ReflectiveClassBuildItem> producer) {
ClassInfo classInfo = index.getClassByName(supplierDotName);
Expand Down Expand Up @@ -395,8 +415,12 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
ClassInfo declarativeAiServiceClassInfo = bi.getServiceClassInfo();
String serviceClassName = declarativeAiServiceClassInfo.name().toString();

String chatLanguageModelSupplierClassName = (bi.getLanguageModelSupplierClassDotName() != null
? bi.getLanguageModelSupplierClassDotName().toString()
String chatLanguageModelSupplierClassName = (bi.getChatLanguageModelSupplierClassDotName() != null
? bi.getChatLanguageModelSupplierClassDotName().toString()
: null);

String streamingChatLanguageModelSupplierClassName = (bi.getStreamingChatLanguageModelSupplierClassDotName() != null
? bi.getStreamingChatLanguageModelSupplierClassDotName().toString()
: null);

List<String> toolClassNames = bi.getToolDotNames().stream().map(DotName::toString).collect(Collectors.toList());
Expand Down Expand Up @@ -427,6 +451,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,

// determine whether the method returns Multi<String>
boolean injectStreamingChatModelBean = false;
// currently in one class either streaming or blocking model are supported, but not both
// if we want to support it, the injectStreamingChatModelBean needs to be recorded per injection point
for (MethodInfo method : declarativeAiServiceClassInfo.methods()) {
if (!LangChain4jDotNames.MULTI.equals(method.returnType().name())) {
continue;
Expand Down Expand Up @@ -459,7 +485,10 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
.configure(QuarkusAiServiceContext.class)
.forceApplicationClass()
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
new DeclarativeAiServiceCreateInfo(
serviceClassName,
chatLanguageModelSupplierClassName,
streamingChatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName, retrieverClassName,
retrievalAugmentorSupplierClassName,
auditServiceClassSupplierName,
Expand All @@ -475,7 +504,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
.done()
.scope(Dependent.class);

if ((chatLanguageModelSupplierClassName == null) && !selectedChatModelProvider.isEmpty()) {
boolean hasChatModelSupplier = chatLanguageModelSupplierClassName == null
&& streamingChatLanguageModelSupplierClassName == null;
if (hasChatModelSupplier && !selectedChatModelProvider.isEmpty()) {
if (NamedConfigUtil.isDefault(chatModelName)) {
configurator.addInjectionPoint(ClassType.create(LangChain4jDotNames.CHAT_MODEL));
if (injectStreamingChatModelBean) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {

private final ClassInfo serviceClassInfo;
private final DotName languageModelSupplierClassDotName;
private final DotName chatLanguageModelSupplierClassDotName;
private final DotName streamingChatLanguageModelSupplierClassDotName;
private final List<DotName> toolDotNames;

private final DotName chatMemoryProviderSupplierClassDotName;
Expand All @@ -27,7 +28,10 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final String chatModelName;
private final String moderationModelName;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
public DeclarativeAiServiceBuildItem(
ClassInfo serviceClassInfo,
DotName chatLanguageModelSupplierClassDotName,
DotName streamingChatLanguageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverClassDotName,
Expand All @@ -40,7 +44,8 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag
String chatModelName,
String moderationModelName) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverClassDotName = retrieverClassDotName;
Expand All @@ -58,8 +63,12 @@ public ClassInfo getServiceClassInfo() {
return serviceClassInfo;
}

public DotName getLanguageModelSupplierClassDotName() {
return languageModelSupplierClassDotName;
public DotName getChatLanguageModelSupplierClassDotName() {
return chatLanguageModelSupplierClassDotName;
}

public DotName getStreamingChatLanguageModelSupplierClassDotName() {
return streamingChatLanguageModelSupplierClassDotName;
}

public List<DotName> getToolDotNames() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public class LangChain4jDotNames {
static final DotName BEAN_CHAT_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanChatLanguageModelSupplier.class);

static final DotName BEAN_STREAMING_CHAT_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanStreamingChatLanguageModelSupplier.class);

static final DotName CHAT_MEMORY_PROVIDER = DotName.createSimple(ChatMemoryProvider.class);

static final DotName BEAN_CHAT_MEMORY_PROVIDER_SUPPLIER = DotName.createSimple(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package io.quarkiverse.langchain4j.test;

import static io.restassured.RestAssured.get;
import static org.hamcrest.Matchers.equalTo;

import java.util.function.Supplier;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

public class BlockingChatLanguageModelSupplierTest {
@Path("/test")
static class MyResource {

private final MyService service;

MyResource(MyService service) {
this.service = service;
}

@GET
public String blocking() {
return service.chat("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
}
}

public static class MyModelSupplier implements Supplier<ChatLanguageModel> {
@Override
public ChatLanguageModel get() {
return (messages) -> new Response<>(new AiMessage("42"));
}
}

@RegisterAiService(chatLanguageModelSupplier = MyModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
interface MyService {
String chat(String msg);
}

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyResource.class, MyService.class, MyModelSupplier.class));

@Test
public void testCall() {
get("test")
.then()
.statusCode(200)
.body(equalTo("42"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package io.quarkiverse.langchain4j.test;

import static io.restassured.RestAssured.get;
import static org.hamcrest.Matchers.equalTo;

import java.util.function.Supplier;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;
import io.smallrye.mutiny.Multi;

public class StreamingAndBlockingChatLanguageModelSupplierTest {
@Path("/test")
static class MyResource {

private final MyService service;

MyResource(MyService service) {
this.service = service;
}

@GET
@Path("/blocking")
public String blocking() {
return service.blocking("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
}

@GET
@Path("/streaming")
public Multi<String> streaming() {
return service.streaming("what is the Answer to the Ultimate Question of Life, the Universe, and Everything?");
}
}

public static class MyModelSupplier implements Supplier<ChatLanguageModel> {
@Override
public ChatLanguageModel get() {
return (messages) -> new Response<>(new AiMessage("42"));
}
}

public static class MyStreamingModelSupplier implements Supplier<StreamingChatLanguageModel> {
@Override
public StreamingChatLanguageModel get() {
return (messages, handler) -> {
handler.onNext("4");
handler.onNext("2");
handler.onComplete(new Response<>(new AiMessage("")));
};
}
}

@RegisterAiService(chatLanguageModelSupplier = MyModelSupplier.class, streamingChatLanguageModelSupplier = MyStreamingModelSupplier.class, chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
interface MyService {
Multi<String> streaming(String msg);

String blocking(String msg);
}

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyResource.class, MyService.class, MyModelSupplier.class, MyStreamingModelSupplier.class));

@Test
public void testCalls() {
get("test/blocking")
.then()
.statusCode(200)
.body(equalTo("42"));
get("test/streaming")
.then()
.statusCode(200)
.body(equalTo("42"));
}
}
Loading

0 comments on commit 664087b

Please sign in to comment.