Skip to content

Commit

Permalink
Merge pull request #1061 from manovotn/registerAllKnownModels
Browse files Browse the repository at this point in the history
Register all models from configuration as beans regardless of the existance of injection point
  • Loading branch information
geoand authored Nov 8, 2024
2 parents 6752844 + b1c52cb commit 084bcf1
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
Expand Down Expand Up @@ -100,9 +101,36 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
Set<String> requestedImageModels = new HashSet<>();
Set<String> tokenCountEstimators = new HashSet<>();

// detection of injection points for default models
boolean defaultChatModelRequested = false;
boolean defaultScoringModelRequested = false;
boolean defaultEmbeddingModelRequested = false;
boolean defaultModerationModelRequested = false;
boolean defaultImageModelRequested = false;

// default model names
final String chatModelConfigNamespace = "chat-model";
final String embeddingModelConfigNamespace = "embedding-model";
final String scoringModelConfigNamespace = "scoring-model";
final String moderationModelConfigNamespace = "moderation-model";
final String imageModelConfigNamespace = "image-model";

// separator symbol for named configs
final String dot = ".";

// bean types for models
final String chatModelBeanType = "ChatLanguageModel or StreamingChatLanguageModel";
final String embeddingModelBeanType = "EmbeddingModel";
final String scoringModelBeanType = "ScoringModel";
final String moderationModelBeanType = "ModerationModel";
final String imageModelBeanType = "ImageModel";

for (InjectionPointInfo ip : beanDiscoveryFinished.getInjectionPoints()) {
DotName requiredName = ip.getRequiredType().name();
String modelName = determineModelName(ip);
if (modelName == null) {
continue;
}
if (CHAT_MODEL.equals(requiredName)) {
requestedChatModels.add(modelName);
} else if (STREAMING_CHAT_MODEL.equals(requiredName)) {
Expand Down Expand Up @@ -140,14 +168,15 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().chatModel().provider();
configNamespace = "chat-model";
configNamespace = chatModelConfigNamespace;
defaultChatModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).chatModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".chat-model";
configNamespace = modelName + dot + chatModelConfigNamespace;
}
if (userSelectedProvider.isEmpty() && !NamedConfigUtil.isDefault(modelName)) {
// let's see if the user has configured a model name for one of the named providers
Expand All @@ -163,7 +192,7 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
chatCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ChatLanguageModel.class),
userSelectedProvider,
"ChatLanguageModel or StreamingChatLanguageModel",
chatModelBeanType,
configNamespace);
if (provider != null) {
selectedChatProducer.produce(new SelectedChatModelProviderBuildItem(provider, modelName));
Expand All @@ -177,21 +206,22 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().scoringModel().provider();
configNamespace = "scoring-model";
configNamespace = scoringModelConfigNamespace;
defaultScoringModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).scoringModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".scoring-model";
configNamespace = modelName + dot + scoringModelConfigNamespace;
}

String provider = selectProvider(
scoringCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
userSelectedProvider,
"ScoringModel",
scoringModelBeanType,
configNamespace);
if (provider != null) {
selectedScoringProducer.produce(new SelectedScoringModelProviderBuildItem(provider, modelName));
Expand All @@ -203,22 +233,23 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().embeddingModel().provider();
configNamespace = "embedding-model";
configNamespace = embeddingModelConfigNamespace;
defaultEmbeddingModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).embeddingModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".embedding-model";
configNamespace = modelName + dot + embeddingModelConfigNamespace;
}

String provider = selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider,
"EmbeddingModel",
embeddingModelBeanType,
configNamespace);
if (provider != null) {
selectedEmbeddingProducer.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, modelName));
Expand All @@ -231,7 +262,7 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
Optional<String> userSelectedProvider = buildConfig.defaultConfig().embeddingModel().provider();
String provider = selectEmbeddingModelProvider(inProcessEmbeddingBuildItems, embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider, "EmbeddingModel", "embedding-model");
userSelectedProvider, embeddingModelBeanType, embeddingModelConfigNamespace);
selectedEmbeddingProducer
.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
Expand All @@ -241,21 +272,22 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().moderationModel().provider();
configNamespace = "moderation-model";
configNamespace = moderationModelConfigNamespace;
defaultModerationModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).moderationModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".moderation-model";
configNamespace = modelName + dot + moderationModelConfigNamespace;
}

String provider = selectProvider(
moderationCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ModerationModel.class),
userSelectedProvider,
"ModerationModel",
moderationModelBeanType,
configNamespace);
if (provider != null) {
selectedModerationProducer.produce(new SelectedModerationModelProviderBuildItem(provider, modelName));
Expand All @@ -267,27 +299,173 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
String configNamespace;
if (NamedConfigUtil.isDefault(modelName)) {
userSelectedProvider = buildConfig.defaultConfig().imageModel().provider();
configNamespace = "image-model";
configNamespace = imageModelConfigNamespace;
defaultImageModelRequested = true;
} else {
if (buildConfig.namedConfig().containsKey(modelName)) {
userSelectedProvider = buildConfig.namedConfig().get(modelName).imageModel().provider();
} else {
userSelectedProvider = Optional.empty();
}
configNamespace = modelName + ".image-model";
configNamespace = modelName + dot + imageModelConfigNamespace;
}

String provider = selectProvider(
imageCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ImageModel.class),
userSelectedProvider,
"ImageModel",
imageModelBeanType,
configNamespace);
if (provider != null) {
selectedImageProducer.produce(new SelectedImageModelProviderBuildItem(provider, modelName));
}
}

// There can be configured models for which we found no injection points.
// While we cannot perform full validation of those, we can still add them as beans.
// This enabled injection such as @Inject @Any Instance<ChatLanguageModel>

// process default configuration
LangChain4jBuildConfig.BaseConfig defaultConfig = buildConfig.defaultConfig();
if (!defaultChatModelRequested && !defaultConfig.chatModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.chatModel().provider();
String provider = selectProvider(
chatCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ChatLanguageModel.class),
userSelectedProvider,
chatModelBeanType,
chatModelConfigNamespace);
if (provider != null) {
selectedChatProducer.produce(new SelectedChatModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}
if (!defaultEmbeddingModelRequested && !defaultConfig.embeddingModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.embeddingModel().provider();
String provider = selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider,
embeddingModelBeanType,
embeddingModelConfigNamespace);
if (provider != null) {
selectedEmbeddingProducer
.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}

if (!defaultScoringModelRequested && !defaultConfig.scoringModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.scoringModel().provider();
String provider = selectProvider(
scoringCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
userSelectedProvider,
scoringModelBeanType,
scoringModelConfigNamespace);
if (provider != null) {
selectedScoringProducer
.produce(new SelectedScoringModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}
if (!defaultModerationModelRequested && !defaultConfig.moderationModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.moderationModel().provider();
String provider = selectProvider(
moderationCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ModerationModel.class),
userSelectedProvider,
moderationModelBeanType,
moderationModelConfigNamespace);
if (provider != null) {
selectedModerationProducer
.produce(new SelectedModerationModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}
if (!defaultImageModelRequested && !defaultConfig.imageModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = defaultConfig.imageModel().provider();
String provider = selectProvider(
imageCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ImageModel.class),
userSelectedProvider,
imageModelBeanType,
imageModelConfigNamespace);
if (provider != null) {
selectedImageProducer.produce(new SelectedImageModelProviderBuildItem(provider, NamedConfigUtil.DEFAULT_NAME));
}
}

// process named configuration
for (Map.Entry<String, LangChain4jBuildConfig.BaseConfig> entry : buildConfig.namedConfig().entrySet()) {
LangChain4jBuildConfig.BaseConfig value = entry.getValue();
if (!requestedStreamingChatModels.contains(entry.getKey()) &&
!requestedChatModels.contains(entry.getKey()) &&
!value.chatModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.chatModel().provider();
String configNamespace = entry.getKey() + dot + chatModelConfigNamespace;
String provider = selectProvider(
chatCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ChatLanguageModel.class),
userSelectedProvider,
chatModelBeanType,
configNamespace);
if (provider != null) {
selectedChatProducer.produce(new SelectedChatModelProviderBuildItem(provider, entry.getKey()));
}
}
if (!requestEmbeddingModels.contains(entry.getKey()) && !value.embeddingModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.embeddingModel().provider();
String configNamespace = entry.getKey() + dot + embeddingModelConfigNamespace;
String provider = selectEmbeddingModelProvider(
inProcessEmbeddingBuildItems,
embeddingCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(EmbeddingModel.class),
userSelectedProvider,
embeddingModelBeanType,
configNamespace);
if (provider != null) {
selectedEmbeddingProducer.produce(new SelectedEmbeddingModelCandidateBuildItem(provider, entry.getKey()));
}
}
if (!requestScoringModels.contains(entry.getKey()) && !value.scoringModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.scoringModel().provider();
String configNamespace = entry.getKey() + dot + scoringModelConfigNamespace;
String provider = selectProvider(
scoringCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ScoringModel.class),
userSelectedProvider,
scoringModelBeanType,
configNamespace);
if (provider != null) {
selectedScoringProducer.produce(new SelectedScoringModelProviderBuildItem(provider, entry.getKey()));
}
}
if (!requestedModerationModels.contains(entry.getKey()) && !value.moderationModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.moderationModel().provider();
String configNamespace = entry.getKey() + dot + moderationModelConfigNamespace;
String provider = selectProvider(
moderationCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ModerationModel.class),
userSelectedProvider,
moderationModelBeanType,
configNamespace);
if (provider != null) {
selectedModerationProducer.produce(new SelectedModerationModelProviderBuildItem(provider, entry.getKey()));
}
}
if (!requestedImageModels.contains(entry.getKey()) && !value.imageModel().provider().isEmpty()) {
Optional<String> userSelectedProvider = value.imageModel().provider();
String configNamespace = entry.getKey() + dot + imageModelConfigNamespace;
String provider = selectProvider(
imageCandidateItems,
beanDiscoveryFinished.beanStream().withBeanType(ImageModel.class),
userSelectedProvider,
imageModelBeanType,
configNamespace);
if (provider != null) {
selectedImageProducer.produce(new SelectedImageModelProviderBuildItem(provider, entry.getKey()));
}
}
}

}

private String determineModelName(InjectionPointInfo ip) {
Expand All @@ -298,6 +476,10 @@ private String determineModelName(InjectionPointInfo ip) {
return value;
}
}
// @Inject @Any Instance<Foo> should not be treated as default name
if (modelNameInstance == null && ip.isProgrammaticLookup()) {
return null;
}
return NamedConfigUtil.DEFAULT_NAME;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,15 @@ quarkus.langchain4j.watsonx.s1.api-key=test
quarkus.langchain4j.watsonx.s1.project-id=proj
quarkus.langchain4j.s2.scoring-model.provider=cohere
quarkus.langchain4j.cohere.s2.api-key=test

# Following models intentionally have no explicit injection points in tests
quarkus.langchain4j.c10.chat-model.provider=watsonx
quarkus.langchain4j.watsonx.c10.base-url=https://somecluster.somedomain.ai:443/api
quarkus.langchain4j.watsonx.c10.api-key=test9
quarkus.langchain4j.watsonx.c10.project-id=proj
quarkus.langchain4j.watsonx.c10.mode=generation
quarkus.langchain4j.e4.embedding-model.provider=ollama
quarkus.langchain4j.c11.moderation-model.provider=openai
quarkus.langchain4j.openai.c11.api-key=test2
quarkus.langchain4j.s3.scoring-model.provider=cohere
quarkus.langchain4j.cohere.s3.api-key=test
Loading

0 comments on commit 084bcf1

Please sign in to comment.