Skip to content

Commit

Permalink
Merge pull request #258 from quarkiverse/multi-model
Browse files Browse the repository at this point in the history
  • Loading branch information
cescoffier authored Feb 2, 2024
2 parents 6dd0d43 + af65a96 commit c9bf7a9
Show file tree
Hide file tree
Showing 70 changed files with 3,299 additions and 812 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.CHAT_MODEL;
import static io.quarkiverse.langchain4j.deployment.Langchain4jDotNames.EMBEDDING_MODEL;

import java.util.Optional;
import java.util.List;

import jakarta.enterprise.context.ApplicationScoped;

import org.jboss.jandex.AnnotationInstance;

import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.bam.runtime.BamRecorder;
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
import io.quarkiverse.langchain4j.deployment.items.ChatModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.EmbeddingModelProviderCandidateBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedEmbeddingModelCandidateBuildItem;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
Expand Down Expand Up @@ -49,31 +53,43 @@ public void providerCandidates(BuildProducer<ChatModelProviderCandidateBuildItem
@BuildStep
@Record(ExecutionTime.RUNTIME_INIT)
void generateBeans(BamRecorder recorder,
Optional<SelectedChatModelProviderBuildItem> selectedChatItem,
Optional<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
List<SelectedChatModelProviderBuildItem> selectedChatItem,
List<SelectedEmbeddingModelCandidateBuildItem> selectedEmbedding,
Langchain4jBamConfig config,
BuildProducer<SyntheticBeanBuildItem> beanProducer) {

if (selectedChatItem.isPresent() && PROVIDER.equals(selectedChatItem.get().getProvider())) {
beanProducer.produce(SyntheticBeanBuildItem
.configure(CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.chatModel(config))
.done());
for (var selected : selectedChatItem) {
if (PROVIDER.equals(selected.getProvider())) {
String modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(CHAT_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.chatModel(config, modelName));
addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}

for (var selected : selectedEmbedding) {
if (PROVIDER.equals(selected.getProvider())) {
String modelName = selected.getModelName();
var builder = SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config, modelName));
addQualifierIfNecessary(builder, modelName);
beanProducer.produce(builder.done());
}
}
}

if (selectedEmbedding.isPresent() && PROVIDER.equals(selectedEmbedding.get().getProvider())) {
beanProducer.produce(
SyntheticBeanBuildItem
.configure(EMBEDDING_MODEL)
.setRuntimeInit()
.defaultBean()
.scope(ApplicationScoped.class)
.supplier(recorder.embeddingModel(config))
.unremovable()
.done());
private void addQualifierIfNecessary(SyntheticBeanBuildItem.ExtendedBeanConfigurator builder, String modelName) {
if (!NamedModelUtil.isDefault(modelName)) {
builder.addQualifier(AnnotationInstance.builder(ModelName.class).add("value", modelName).build());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ interface NewAIService {
NewAIService service;

@Inject
Langchain4jBamConfig config;
Langchain4jBamConfig langchain4jBamConfig;

@Test
void chat() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

var modelId = config.chatModel().modelId();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class AllPropertiesTest {
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@Inject
Langchain4jBamConfig config;
Langchain4jBamConfig langchain4jBamConfig;

@Inject
ChatLanguageModel model;
Expand All @@ -79,6 +79,7 @@ static void afterAll() {

@Test
void generate() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

assertEquals(WireMockUtil.URL, config.baseUrl().get().toString());
assertEquals(WireMockUtil.API_KEY, config.apiKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class DefaultPropertiesTest {
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClass(WireMockUtil.class));

@Inject
Langchain4jBamConfig config;
Langchain4jBamConfig langchain4jBamConfig;

@Inject
ChatLanguageModel model;
Expand All @@ -59,6 +59,7 @@ static void afterAll() {

@Test
void generate() throws Exception {
var config = langchain4jBamConfig.defaultConfig();

assertEquals(Duration.ofSeconds(10), config.timeout());
assertEquals("2024-01-10", config.version());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.quarkiverse.langchain4j.bam.BamException.Code;
import io.quarkiverse.langchain4j.bam.BamException.Reason;
import io.quarkiverse.langchain4j.bam.BamRestApi;
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
import io.quarkus.test.QuarkusUnitTest;

public class HttpErrorTest {
Expand All @@ -36,9 +35,6 @@ public class HttpErrorTest {
static ObjectMapper mapper;
static WireMockUtil mockServers;

@Inject
Langchain4jBamConfig config;

@Inject
ChatLanguageModel model;

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,32 @@
import io.quarkiverse.langchain4j.bam.BamChatModel;
import io.quarkiverse.langchain4j.bam.BamEmbeddingModel;
import io.quarkiverse.langchain4j.bam.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.bam.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.bam.runtime.config.Langchain4jBamConfig;
import io.quarkiverse.langchain4j.runtime.NamedModelUtil;
import io.quarkus.runtime.annotations.Recorder;
import io.smallrye.config.ConfigValidationException;

@Recorder
public class BamRecorder {

public Supplier<?> chatModel(Langchain4jBamConfig runtimeConfig) {
ChatModelConfig chatModelConfig = runtimeConfig.chatModel();
private static final String DUMMY_KEY = "dummy";

public Supplier<?> chatModel(Langchain4jBamConfig runtimeConfig, String modelName) {
Langchain4jBamConfig.BamConfig bamConfig = correspondingBamConfig(runtimeConfig, modelName);
ChatModelConfig chatModelConfig = bamConfig.chatModel();
String apiKey = bamConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}

var builder = BamChatModel.builder()
.accessToken(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.accessToken(bamConfig.apiKey())
.timeout(bamConfig.timeout())
.logRequests(bamConfig.logRequests())
.logResponses(bamConfig.logResponses())
.modelId(chatModelConfig.modelId())
.version(runtimeConfig.version())
.version(bamConfig.version())
.decodingMethod(chatModelConfig.decodingMethod())
.minNewTokens(chatModelConfig.minNewTokens())
.maxNewTokens(chatModelConfig.maxNewTokens())
Expand All @@ -38,8 +48,8 @@ public Supplier<?> chatModel(Langchain4jBamConfig runtimeConfig) {
.truncateInputTokens(firstOrDefault(null, chatModelConfig.truncateInputTokens()))
.beamWidth(firstOrDefault(null, chatModelConfig.beamWidth()));

if (runtimeConfig.baseUrl().isPresent()) {
builder.url(runtimeConfig.baseUrl().get());
if (bamConfig.baseUrl().isPresent()) {
builder.url(bamConfig.baseUrl().get());
}

return new Supplier<>() {
Expand All @@ -50,18 +60,22 @@ public Object get() {
};
}

public Supplier<?> embeddingModel(Langchain4jBamConfig runtimeConfig) {

var embeddingModelConfig = runtimeConfig.embeddingModel();
public Supplier<?> embeddingModel(Langchain4jBamConfig runtimeConfig, String modelName) {
Langchain4jBamConfig.BamConfig bamConfig = correspondingBamConfig(runtimeConfig, modelName);
EmbeddingModelConfig embeddingModelConfig = bamConfig.embeddingModel();
String apiKey = bamConfig.apiKey();
if (DUMMY_KEY.equals(apiKey)) {
throw new ConfigValidationException(createApiKeyConfigProblem(modelName));
}

var builder = BamEmbeddingModel.builder()
.accessToken(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.version(runtimeConfig.version())
.accessToken(bamConfig.apiKey())
.timeout(bamConfig.timeout())
.version(bamConfig.version())
.modelId(embeddingModelConfig.modelId());

if (runtimeConfig.baseUrl().isPresent()) {
builder.url(runtimeConfig.baseUrl().get());
if (bamConfig.baseUrl().isPresent()) {
builder.url(bamConfig.baseUrl().get());
}

return new Supplier<>() {
Expand All @@ -71,4 +85,28 @@ public Object get() {
}
};
}

private Langchain4jBamConfig.BamConfig correspondingBamConfig(Langchain4jBamConfig runtimeConfig, String modelName) {
Langchain4jBamConfig.BamConfig bamConfig;
if (NamedModelUtil.isDefault(modelName)) {
bamConfig = runtimeConfig.defaultConfig();
} else {
bamConfig = runtimeConfig.namedConfig().get(modelName);
}
return bamConfig;
}

private ConfigValidationException.Problem[] createApiKeyConfigProblem(String modelName) {
return createConfigProblems("api-key", modelName);
}

private ConfigValidationException.Problem[] createConfigProblems(String key, String modelName) {
return new ConfigValidationException.Problem[] { createConfigProblem(key, modelName) };
}

private static ConfigValidationException.Problem createConfigProblem(String key, String modelName) {
return new ConfigValidationException.Problem(String.format(
"SRCFG00014: The config property quarkus.langchain4j.bam%s%s is required but it could not be found in any config source",
NamedModelUtil.isDefault(modelName) ? "." : ("." + modelName + "."), key));
}
}
Loading

0 comments on commit c9bf7a9

Please sign in to comment.