Skip to content

Commit

Permalink
Merge pull request #60 from cescoffier/in-process-model
Browse files Browse the repository at this point in the history
Add support for in-process embedding models
  • Loading branch information
geoand authored Nov 24, 2023
2 parents 3700027 + a6f7750 commit c0f9314
Show file tree
Hide file tree
Showing 42 changed files with 1,462 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,13 @@
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.poi.ss.formula.functions.T;
import org.jboss.jandex.DotName;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ModerationModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ProviderHolder;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedModerationModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.*;
import io.quarkiverse.langchain4j.runtime.Langchain4jRecorder;
import io.quarkus.arc.deployment.BeanDiscoveryFinishedBuildItem;
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
Expand Down Expand Up @@ -57,7 +52,8 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
LangChain4jBuildConfig buildConfig,
BuildProducer<SelectedChatModelProviderBuildItem> selectedChatProducer,
BuildProducer<SelectedEmbeddingModelCandidateBuildItem> selectedEmbeddingProducer,
BuildProducer<SelectedModerationModelProviderBuildItem> selectedModerationProducer) {
BuildProducer<SelectedModerationModelProviderBuildItem> selectedModerationProducer,
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems) {

boolean chatModelBeanRequested = false;
boolean streamingChatModelBeanRequested = false;
Expand Down Expand Up @@ -91,7 +87,8 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
if (embeddingModelBeanRequested) {
selectedEmbeddingProducer.produce(
new SelectedEmbeddingModelCandidateBuildItem(
selectProvider(
selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
buildConfig.embeddingModel().provider(),
"EmbeddingModel",
Expand All @@ -110,7 +107,8 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private <T extends ProviderHolder> String selectProvider(List<T> chatCandidateItems,
private <T extends ProviderHolder> String selectProvider(
List<T> chatCandidateItems,
Optional<String> userSelectedProvider,
String requestedBeanName,
String configNamespace) {
Expand Down Expand Up @@ -139,6 +137,41 @@ private <T extends ProviderHolder> String selectProvider(List<T> chatCandidateIt
requestedBeanName, configNamespace, String.join(",", availableProviders)));
}

private <T extends ProviderHolder> String selectEmbeddingModelProvider(
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems,
List<T> chatCandidateItems,
Optional<String> userSelectedProvider,
String requestedBeanName,
String configNamespace) {
List<String> availableProviders = chatCandidateItems.stream().map(ProviderHolder::getProvider)
.collect(Collectors.toList());
availableProviders.addAll(inProcessEmbeddingBuildItems.stream().map(InProcessEmbeddingBuildItem::getProvider)
.toList());
if (availableProviders.isEmpty()) {
throw new ConfigurationException(String.format(
"A %s bean was requested, but no langchain4j providers were configured and no in-process embedding model were found on the classpath. "
+
"Consider adding an extension like 'quarkus-langchain4j-openai' or one of the in-process embedding model.",
requestedBeanName));
}
if (availableProviders.size() == 1) {
return availableProviders.get(0);
}
// multiple providers exist, so we now need the configuration to select the proper one
if (userSelectedProvider.isEmpty()) {
throw new ConfigurationException(String.format(
"A %s bean was requested, but since there are multiple available providers, the 'quarkus.langchain4j.%s.provider' needs to be set to one of the available options (%s).",
requestedBeanName, configNamespace, String.join(",", availableProviders)));
}
boolean matches = availableProviders.stream().anyMatch(ap -> ap.equals(userSelectedProvider.get()));
if (matches) {
return userSelectedProvider.get();
}
throw new ConfigurationException(String.format(
"A %s bean was requested, but the value of 'quarkus.langchain4j.%s.provider' does not match any of the available options (%s).",
requestedBeanName, configNamespace, String.join(",", availableProviders)));
}

@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
public void cleanUp(Langchain4jRecorder recorder, ShutdownContextBuildItem shutdown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
import java.util.zip.ZipEntry;

import dev.langchain4j.data.document.splitter.DocumentBySentenceSplitter;
import io.quarkiverse.langchain4j.deployment.items.InProcessEmbeddingBuildItem;
import io.quarkus.bootstrap.classloading.QuarkusClassLoader;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.builditem.nativeimage.NativeImageResourceBuildItem;
import io.quarkus.deployment.builditem.nativeimage.NativeImageResourcePatternsBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.RuntimeInitializedPackageBuildItem;
import io.quarkus.deployment.builditem.nativeimage.*;

/**
* TODO: we might want to make this more granular so all these document related dependencies don't always end up in the
Expand All @@ -26,24 +23,33 @@
public class DocumentNativeSupportProcessor {

@BuildStep
void onnxJni(BuildProducer<NativeImageResourcePatternsBuildItem> nativePatternProducer,
void onnxJni(
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems,
BuildProducer<NativeImageResourcePatternsBuildItem> nativePatternProducer,
BuildProducer<ReflectiveClassBuildItem> reflectionProducer) {
// TODO: we can do better here and only include the target architecture's libs
nativePatternProducer
.produce(NativeImageResourcePatternsBuildItem.builder().includeGlobs("ai/onnxruntime/native/**").build());
reflectionProducer
.produce(ReflectiveClassBuildItem.builder("opennlp.tools.sentdetect.SentenceDetectorFactory").build());
reflectionProducer.produce(ReflectiveClassBuildItem.builder("ai.onnxruntime.OnnxTensor").methods(true).build());
if (!inProcessEmbeddingBuildItems.isEmpty()) {
// TODO: we can do better here and only include the target architecture's libs
nativePatternProducer
.produce(NativeImageResourcePatternsBuildItem.builder().includeGlobs("ai/onnxruntime/native/**").build());
reflectionProducer
.produce(ReflectiveClassBuildItem.builder("opennlp.tools.sentdetect.SentenceDetectorFactory").build());
reflectionProducer.produce(ReflectiveClassBuildItem.builder("ai.onnxruntime.OnnxTensor").methods(true).build());
}
}

@BuildStep
void apachePoiRuntimeClasses(BuildProducer<RuntimeInitializedClassBuildItem> classProducer,
void apachePoiRuntimeClasses(
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems,
BuildProducer<RuntimeInitializedClassBuildItem> classProducer,
BuildProducer<RuntimeInitializedPackageBuildItem> packageProducer) {
Stream.of(
"dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel",
"dev.langchain4j.model.embedding.OnnxBertBiEncoder",
"ai.onnxruntime.OrtEnvironment",
"ai.onnxruntime.OnnxRuntime",
"ai.onnxruntime.OnnxTensorLike",
"ai.onnxruntime.OrtAllocator",
"ai.onnxruntime.OrtSession$SessionOptions",
"ai.onnxruntime.OrtSession",
"org.apache.fontbox.ttf.RAFDataStream",
"org.apache.fontbox.ttf.TTFParser",
"org.apache.pdfbox.pdmodel.encrypetion.PublicKeySecurityHandler",
Expand All @@ -61,15 +67,33 @@ void apachePoiRuntimeClasses(BuildProducer<RuntimeInitializedClassBuildItem> cla
.filter(QuarkusClassLoader::isClassPresentAtRuntime)
.map(RuntimeInitializedClassBuildItem::new).forEach(classProducer::produce);

for (InProcessEmbeddingBuildItem inProcessEmbeddingBuildItem : inProcessEmbeddingBuildItems) {
classProducer.produce(new RuntimeInitializedClassBuildItem(inProcessEmbeddingBuildItem.className()));
}

packageProducer.produce(new RuntimeInitializedPackageBuildItem("com.microsoft.schemas.office"));
}

@BuildStep
void openNLPResources(BuildProducer<NativeImageResourceBuildItem> producer) {
void includeInProcessEmbeddingModels(
List<InProcessEmbeddingBuildItem> inProcessEmbeddingBuildItems,
BuildProducer<NativeImageResourceBuildItem> resources,
BuildProducer<ReflectiveClassBuildItem> reflection) {
for (InProcessEmbeddingBuildItem inProcessEmbeddingBuildItem : inProcessEmbeddingBuildItems) {
resources.produce(new NativeImageResourceBuildItem(inProcessEmbeddingBuildItem.onnxModelPath()));
resources.produce(new NativeImageResourceBuildItem(inProcessEmbeddingBuildItem.vocabularyPath()));
reflection.produce(ReflectiveClassBuildItem.builder(inProcessEmbeddingBuildItem.className())
.constructors(true)
.fields(true)
.methods(true)
.build());
}
}

@BuildStep
void openNLPResources(
BuildProducer<NativeImageResourceBuildItem> producer) {
registerCustomOpenNLPResources(producer);
// TODO: maybe we should also be opening jars and getting these?
producer.produce(new NativeImageResourceBuildItem("bert-vocabulary-en.txt"));
producer.produce(new NativeImageResourceBuildItem("all-minilm-l6-v2.onnx"));
}

private void registerCustomOpenNLPResources(BuildProducer<NativeImageResourceBuildItem> resourcesProducer) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package io.quarkiverse.langchain4j.deployment;

import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.DotName;

import dev.langchain4j.model.embedding.EmbeddingModel;
import io.quarkiverse.langchain4j.deployment.items.InProcessEmbeddingBuildItem;
import io.quarkiverse.langchain4j.runtime.InProcessEmbeddingRecorder;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.bootstrap.classloading.QuarkusClassLoader;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;

/**
* Generate a local embedding build item for each local embedding model available in the classpath.
* Note that the user must have the dependency for the model in their pom.xml/build.gradle.
*/
public class InProcessEmbeddingProcessor {

// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
@BuildStep
InProcessEmbeddingBuildItem all_minilm_l6_v2_q() {
if (QuarkusClassLoader
.isClassPresentAtRuntime("dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel")) {
return new InProcessEmbeddingBuildItem("all-minilm-l6-v2-q",
"dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel",
"all-minilm-l6-v2-q.onnx", "bert-vocabulary-en.txt");
} else {
return null;
}
}

// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
@BuildStep
InProcessEmbeddingBuildItem all_minilm_l6_v2() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel")) {
return new InProcessEmbeddingBuildItem("all-minilm-l6-v2",
"dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel",
"all-minilm-l6-v2.onnx", "bert-vocabulary-en.txt");
} else {
return null;
}
}

// https://huggingface.co/neuralmagic/bge-small-en-v1.5-quant
@BuildStep
InProcessEmbeddingBuildItem bge_small_en_q() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.BgeSmallEnQuantizedEmbeddingModel")) {
return new InProcessEmbeddingBuildItem("bge-small-en-q",
"dev.langchain4j.model.embedding.BgeSmallEnQuantizedEmbeddingModel",
"bge-small-en-q.onnx", "bert-vocabulary-en.txt");
} else {
return null;
}
}

// https://huggingface.co/BAAI/bge-small-en-v1.5
@BuildStep
InProcessEmbeddingBuildItem bge_small_en() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.BgeSmallEnEmbeddingModel")) {
return new InProcessEmbeddingBuildItem("bge-small-en", "dev.langchain4j.model.embedding.BgeSmallEnEmbeddingModel",
"bge-small-en.onnx", "bert-vocabulary-en.txt");
} else {
return null;
}
}

@BuildStep
InProcessEmbeddingBuildItem bge_small_zh_q() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.BgeSmallZhQuantizedEmbeddingModel")) {
return new InProcessEmbeddingBuildItem("bge-small-zh-q",
"dev.langchain4j.model.embedding.BgeSmallZhQuantizedEmbeddingModel",
"bge-small-zh-q.onnx", "bge-small-zh-vocabulary.txt");
} else {
return null;
}
}

// https://huggingface.co/BAAI/bge-small-zh
@BuildStep
InProcessEmbeddingBuildItem bge_small_zh() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel")) {
return new InProcessEmbeddingBuildItem("bge-small-zh", "dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel",
"bge-small-zh.onnx", "bge-small-zh-vocabulary.txt");
} else {
return null;
}
}

// https://huggingface.co/intfloat/e5-small-v2
@BuildStep
InProcessEmbeddingBuildItem e5_small_v2_q() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.E5SmallV2QuantizedEmbeddingModel")) {
return new InProcessEmbeddingBuildItem("e5-small-v2-q",
"dev.langchain4j.model.embedding.E5SmallV2QuantizedEmbeddingModel",
"e5-small-v2-q.onnx", "bert-vocabulary-en.txt");
} else {
return null;
}
}

// https://huggingface.co/intfloat/e5-small-v2
@BuildStep
InProcessEmbeddingBuildItem e5_small_v2() {
if (QuarkusClassLoader.isClassPresentAtRuntime("dev.langchain4j.model.embedding.E5SmallV2EmbeddingModel")) {
return new InProcessEmbeddingBuildItem("e5-small-v2", "dev.langchain4j.model.embedding.E5SmallV2EmbeddingModel",
"e5-small-v2.onnx", "bert-vocabulary-en.txt");
} else {
return null;
}
}

// Expose a bean for each in process embedding model
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void exposeInProcessEmbeddingBeans(InProcessEmbeddingRecorder recorder,
List<InProcessEmbeddingBuildItem> embeddings,
BuildProducer<SyntheticBeanBuildItem> beanProducer) {

for (InProcessEmbeddingBuildItem embedding : embeddings) {
beanProducer.produce(SyntheticBeanBuildItem
.configure(DotName.createSimple(embedding.className()))
.types(EmbeddingModel.class)
.defaultBean()
.setRuntimeInit()
.scope(ApplicationScoped.class)
.supplier(recorder.instantiate(embedding.className()))
.done());
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.quarkiverse.langchain4j.deployment.items;

import io.quarkus.builder.item.MultiBuildItem;

public final class InProcessEmbeddingBuildItem extends MultiBuildItem implements ProviderHolder {

private final String modelName;
private final String onnxModelPath;
private final String vocabularyPath;

private final String className;

public InProcessEmbeddingBuildItem(String modelName, String className, String onnxModelPath, String vocabularyPath) {
this.modelName = modelName;
this.className = className;
this.onnxModelPath = onnxModelPath;
this.vocabularyPath = vocabularyPath;
}

public String modelName() {
return modelName;
}

public String onnxModelPath() {
return onnxModelPath;
}

public String vocabularyPath() {
return vocabularyPath;
}

public String className() {
return className;
}

@Override
public String getProvider() {
return className;
}
}
Loading

0 comments on commit c0f9314

Please sign in to comment.