Skip to content

Commit

Permalink
Register all models from configuration as beans regardless of the exi…
Browse files Browse the repository at this point in the history
…stance of injection point
  • Loading branch information
manovotn committed Nov 8, 2024
1 parent 6efb975 commit b1c52cb
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
Expand Down Expand Up @@ -100,9 +101,36 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
Set<String> requestedImageModels = new HashSet<>();
Set<String> tokenCountEstimators = new HashSet<>();

// detection of injection points for default models
boolean defaultChatModelRequested = false;
boolean defaultScoringModelRequested = false;
boolean defaultEmbeddingModelRequested = false;
boolean defaultModerationModelRequested = false;
boolean defaultImageModelRequested = false;

// default model names
final String chatModelConfigNamespace = "chat-model";
final String embeddingModelConfigNamespace = "embedding-model";
final String scoringModelConfigNamespace = "scoring-model";
final String moderationModelConfigNamespace = "moderation-model";
final String imageModelConfigNamespace = "image-model";

// separator symbol for named configs
final String dot = ".";

// bean types for models
final String chatModelBeanType = "ChatLanguageModel or StreamingChatLanguageModel";
final String embeddingModelBeanType = "EmbeddingModel";
final String scoringModelBeanType = "ScoringModel";
final String moderationModelBeanType = "ModerationModel";
final String imageModelBeanType = "ImageModel";

for (InjectionPointInfo ip : beanDiscoveryFinished.getInjectionPoints()) {
DotName requiredName = ip.getRequiredType().name();
String modelName = determineModelName(ip);
if (modelName == null) {
continue;
}
if (CHAT_MODEL.equals(requiredName)) {
requestedChatModels.add(modelName);
} else if (STREAMING_CHAT_MODEL.equals(requiredName)) {
Expand Down Expand Up @@ -140,14 +168,15 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().chatModel().provider();
configNamespace = "chat-model";
configNamespace = chatModelConfigNamespace;
defaultChatModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).chatModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".chat-model";
configNamespace = modelName + dot + chatModelConfigNamespace;
}
if (userSelectedProvider.isEmpty() && !NamedConfigUtil.isDefault(modelName)) {
// let's see if the user has configured a model name for one of the named providers
Expand All @@ -163,7 +192,7 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
chatCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ChatLanguageModel.class),
userSelectedProvider,
"ChatLanguageModel or StreamingChatLanguageModel",
chatModelBeanType,
configNamespace);
if (provider != null) {
selectedChatProducer.produce(new SelectedChatModelProviderBuildItem(provider, modelName));
Expand All @@ -177,21 +206,22 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().scoringModel().provider();
configNamespace = "scoring-model";
configNamespace = scoringModelConfigNamespace;
defaultScoringModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).scoringModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".scoring-model";
configNamespace = modelName + dot + scoringModelConfigNamespace;
}

String provider = selectProvider(
scoringCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
userSelectedProvider,
"ScoringModel",
scoringModelBeanType,
configNamespace);
if (provider != null) {
selectedScoringProducer.produce(new SelectedScoringModelProviderBuildItem(provider, modelName));
Expand All @@ -203,22 +233,23 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().embeddingModel().provider();
configNamespace = "embedding-model";
configNamespace = embeddingModelConfigNamespace;
defaultEmbeddingModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).embeddingModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".embedding-model";
configNamespace = modelName + dot + embeddingModelConfigNamespace;
}

String provider = selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider,
"EmbeddingModel",
embeddingModelBeanType,
configNamespace);
if (provider != null) {
selectedEmbeddingProducer.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, modelName));
Expand All @@ -231,7 +262,7 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
Optional<String> userSelectedProvider = buildConfig.defaultConfig().embeddingModel().provider();
String provider = selectEmbeddingModelProvider(inProcessEmbeddingBuildItems, embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider, "EmbeddingModel", "embedding-model");
userSelectedProvider, embeddingModelBeanType, embeddingModelConfigNamespace);
selectedEmbeddingProducer
.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
Expand All @@ -241,21 +272,22 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().moderationModel().provider();
configNamespace = "moderation-model";
configNamespace = moderationModelConfigNamespace;
defaultModerationModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).moderationModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".moderation-model";
configNamespace = modelName + dot + moderationModelConfigNamespace;
}

String provider = selectProvider(
moderationCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ModerationModel.class),
userSelectedProvider,
"ModerationModel",
moderationModelBeanType,
configNamespace);
if (provider != null) {
selectedModerationProducer.produce(new SelectedModerationModelProviderBuildItem(provider, modelName));
Expand All @@ -267,27 +299,173 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().imageModel().provider();
configNamespace = "image-model";
configNamespace = imageModelConfigNamespace;
defaultImageModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).imageModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".image-model";
configNamespace = modelName + dot + imageModelConfigNamespace;
}

String provider = selectProvider(
imageCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ImageModel.class),
userSelectedProvider,
"ImageModel",
imageModelBeanType,
configNamespace);
if (provider != null) {
selectedImageProducer.produce(new SelectedImageModelProviderBuildItem(provider, modelName));
}
}

// There can be configured models for which we found no injection points.
// While we cannot perform full validation of those, we can still add them as beans.
// This enabled injection such as @Inject @Any Instance<ChatLanguageModel>

// process default configuration
LangChain4jBuildConfig.BaseConfig defaultConfig = buildConfig.defaultConfig();
if (!defaultChatModelRequested && !defaultConfig.chatModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.chatModel().provider();
String provider = selectProvider(
chatCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ChatLanguageModel.class),
userSelectedProvider,
chatModelBeanType,
chatModelConfigNamespace);
if (provider != null) {
selectedChatProducer.produce(new SelectedChatModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}
if (!defaultEmbeddingModelRequested && !defaultConfig.embeddingModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.embeddingModel().provider();
String provider = selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider,
embeddingModelBeanType,
embeddingModelConfigNamespace);
if (provider != null) {
selectedEmbeddingProducer
.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}

if (!defaultScoringModelRequested && !defaultConfig.scoringModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.scoringModel().provider();
String provider = selectProvider(
scoringCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
userSelectedProvider,
scoringModelBeanType,
scoringModelConfigNamespace);
if (provider != null) {
selectedScoringProducer
.produce(new SelectedScoringModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}
if (!defaultModerationModelRequested && !defaultConfig.moderationModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.moderationModel().provider();
String provider = selectProvider(
moderationCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ModerationModel.class),
userSelectedProvider,
moderationModelBeanType,
moderationModelConfigNamespace);
if (provider != null) {
selectedModerationProducer
.produce(new SelectedModerationModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}
if (!defaultImageModelRequested && !defaultConfig.imageModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.imageModel().provider();
String provider = selectProvider(
imageCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ImageModel.class),
userSelectedProvider,
imageModelBeanType,
imageModelConfigNamespace);
if (provider != null) {
selectedImageProducer.produce(new SelectedImageModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}

// process named configuration
for (Map.Entry<String, LangChain4jBuildConfig.BaseConfig> entry : buildConfig.namedConfig().entrySet()) {
LangChain4jBuildConfig.BaseConfig value = entry.getValue();
if (!requestedStreamingChatModels.contains(entry.getKey()) &&
!requestedChatModels.contains(entry.getKey()) &&
!value.chatModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.chatModel().provider();
String configNamespace = entry.getKey() + dot + chatModelConfigNamespace;
String provider = selectProvider(
chatCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ChatLanguageModel.class),
userSelectedProvider,
chatModelBeanType,
configNamespace);
if (provider != null) {
selectedChatProducer.produce(new SelectedChatModelProviderBuildItem(provider, entry.getKey()));
}
}
if (!requestEmbeddingModels.contains(entry.getKey()) && !value.embeddingModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.embeddingModel().provider();
String configNamespace = entry.getKey() + dot + embeddingModelConfigNamespace;
String provider = selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider,
embeddingModelBeanType,
configNamespace);
if (provider != null) {
selectedEmbeddingProducer.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, entry.getKey()));
}
}
if (!requestScoringModels.contains(entry.getKey()) && !value.scoringModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.scoringModel().provider();
String configNamespace = entry.getKey() + dot + scoringModelConfigNamespace;
String provider = selectProvider(
scoringCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
userSelectedProvider,
scoringModelBeanType,
configNamespace);
if (provider != null) {
selectedScoringProducer.produce(new SelectedScoringModelProviderBuildItem(provider, entry.getKey()));
}
}
if (!requestedModerationModels.contains(entry.getKey()) && !value.moderationModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.moderationModel().provider();
String configNamespace = entry.getKey() + dot + moderationModelConfigNamespace;
String provider = selectProvider(
moderationCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ModerationModel.class),
userSelectedProvider,
moderationModelBeanType,
configNamespace);
if (provider != null) {
selectedModerationProducer.produce(new SelectedModerationModelProviderBuildItem(provider, entry.getKey()));
}
}
if (!requestedImageModels.contains(entry.getKey()) && !value.imageModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.imageModel().provider();
String configNamespace = entry.getKey() + dot + imageModelConfigNamespace;
String provider = selectProvider(
imageCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ImageModel.class),
userSelectedProvider,
imageModelBeanType,
configNamespace);
if (provider != null) {
selectedImageProducer.produce(new SelectedImageModelProviderBuildItem(provider, entry.getKey()));
}
}
}

}

private String determineModelName(InjectionPointInfo ip) {
Expand All @@ -298,6 +476,10 @@ private String determineModelName(InjectionPointInfo ip) {
return value;
}
}
// @Inject @Any Instance<Foo> should not be treated as default name
if (modelNameInstance == null && ip.isProgrammaticLookup()) {
return null;
}
return NamedConfigUtil.DEFAULT_NAME;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,15 @@ quarkus.langchain4j.watsonx.s1.api-key=test
quarkus.langchain4j.watsonx.s1.project-id=proj
quarkus.langchain4j.s2.scoring-model.provider=cohere
quarkus.langchain4j.cohere.s2.api-key=test

# Following models intentionally have no explicit injection points in tests
quarkus.langchain4j.c10.chat-model.provider=watsonx
quarkus.langchain4j.watsonx.c10.base-url=https://somecluster.somedomain.ai:443/api
quarkus.langchain4j.watsonx.c10.api-key=test9
quarkus.langchain4j.watsonx.c10.project-id=proj
quarkus.langchain4j.watsonx.c10.mode=generation
quarkus.langchain4j.e4.embedding-model.provider=ollama
quarkus.langchain4j.c11.moderation-model.provider=openai
quarkus.langchain4j.openai.c11.api-key=test2
quarkus.langchain4j.s3.scoring-model.provider=cohere
quarkus.langchain4j.cohere.s3.api-key=test
Loading

0 comments on commit b1c52cb

Please sign in to comment.