Skip to content

Commit

Permalink
Merge pull request #1024 from andreadimaio/main
Browse files Browse the repository at this point in the history
Fix incorrect model-id parameter in WatsonxRecorder
  • Loading branch information
geoand authored Oct 31, 2024
2 parents c458c3f + a828a32 commit 4fe7711
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ void handlerBeforeEach() {
@Test
void check_config() throws Exception {
var runtimeConfig = langchain4jWatsonConfig.defaultConfig();
assertEquals(WireMockUtil.URL_WATSONX_SERVER, runtimeConfig.baseUrl().toString());
assertEquals(WireMockUtil.URL_WATSONX_SERVER, runtimeConfig.baseUrl().orElse(null).toString());
assertEquals(WireMockUtil.URL_IAM_SERVER, runtimeConfig.iam().baseUrl().toString());
assertEquals(WireMockUtil.API_KEY, runtimeConfig.apiKey());
assertEquals(WireMockUtil.API_KEY, runtimeConfig.apiKey().orElse(null));
assertEquals("my-space-id", runtimeConfig.spaceId().orElse(null));
assertEquals(WireMockUtil.PROJECT_ID, runtimeConfig.projectId().orElse(null));
assertEquals(Duration.ofSeconds(60), runtimeConfig.timeout().get());
Expand All @@ -114,7 +114,7 @@ void check_config() throws Exception {
@Test
void check_chat_model_config() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.generationModel().modelId();
String modelId = config.chatModel().modelId();
String spaceId = config.spaceId().orElse(null);
String projectId = config.projectId().orElse(null);

Expand All @@ -135,7 +135,7 @@ void check_chat_model_config() throws Exception {
@Test
void check_token_count_estimator() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.generationModel().modelId();
String modelId = config.chatModel().modelId();
String spaceId = config.spaceId().orElse(null);
String projectId = config.projectId().orElse(null);

Expand All @@ -152,7 +152,7 @@ void check_token_count_estimator() throws Exception {
@Test
void check_chat_streaming_model_config() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.generationModel().modelId();
String modelId = config.chatModel().modelId();
String spaceId = config.spaceId().orElse(null);
String projectId = config.projectId().orElse(null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void check_config() throws Exception {
@Test
void check_chat_model_config() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.generationModel().modelId();
String modelId = config.chatModel().modelId();
String spaceId = config.spaceId().orElse(null);
String projectId = config.projectId().orElse(null);

Expand All @@ -113,7 +113,7 @@ void check_chat_model_config() throws Exception {
@Test
void check_token_count_estimator() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.generationModel().modelId();
String modelId = config.chatModel().modelId();
String spaceId = config.spaceId().orElse(null);
String projectId = config.projectId().orElse(null);

Expand All @@ -130,7 +130,7 @@ void check_token_count_estimator() throws Exception {
@Test
void check_chat_streaming_model_config() throws Exception {
var config = langchain4jWatsonConfig.defaultConfig();
String modelId = config.generationModel().modelId();
String modelId = config.chatModel().modelId();
String spaceId = config.spaceId().orElse(null);
String projectId = config.projectId().orElse(null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ void handlerBeforeEach() {
@Test
void check_config() throws Exception {
var runtimeConfig = langchain4jWatsonConfig.defaultConfig();
assertEquals(WireMockUtil.URL_WATSONX_SERVER, runtimeConfig.baseUrl().toString());
assertEquals(WireMockUtil.URL_WATSONX_SERVER, runtimeConfig.baseUrl().orElse(null).toString());
assertEquals(WireMockUtil.URL_IAM_SERVER, runtimeConfig.iam().baseUrl().toString());
assertEquals(WireMockUtil.API_KEY, runtimeConfig.apiKey());
assertEquals(WireMockUtil.API_KEY, runtimeConfig.apiKey().orElse(null));
assertEquals("my-space-id", runtimeConfig.spaceId().orElse(null));
assertEquals(WireMockUtil.PROJECT_ID, runtimeConfig.projectId().orElse(null));
assertEquals(Duration.ofSeconds(60), runtimeConfig.timeout().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
@Recorder
public class WatsonxRecorder {

private static final String DUMMY_URL = "https://dummy.ai/api";
private static final String DUMMY_API_KEY = "dummy";
private static final Map<String, WatsonxTokenGenerator> tokenGeneratorCache = new HashMap<>();
private static final ConfigValidationException.Problem[] EMPTY_PROBLEMS = new ConfigValidationException.Problem[0];

Expand All @@ -47,7 +45,7 @@ public Supplier<ChatLanguageModel> chatModel(LangChain4jWatsonxConfig runtimeCon

if (watsonRuntimeConfig.enableIntegration()) {

var builder = chatBuilder(watsonRuntimeConfig, configName);
var builder = chatBuilder(runtimeConfig, configName);
return new Supplier<>() {
@Override
public ChatLanguageModel get() {
Expand All @@ -74,7 +72,7 @@ public Supplier<StreamingChatLanguageModel> streamingChatModel(LangChain4jWatson

if (watsonRuntimeConfig.enableIntegration()) {

var builder = chatBuilder(watsonRuntimeConfig, configName);
var builder = chatBuilder(runtimeConfig, configName);
return new Supplier<>() {
@Override
public StreamingChatLanguageModel get() {
Expand All @@ -101,7 +99,7 @@ public Supplier<ChatLanguageModel> generationModel(LangChain4jWatsonxConfig runt

if (watsonRuntimeConfig.enableIntegration()) {

var builder = generationBuilder(watsonRuntimeConfig, configName);
var builder = generationBuilder(runtimeConfig, configName);
return new Supplier<>() {
@Override
public ChatLanguageModel get() {
Expand All @@ -128,7 +126,7 @@ public Supplier<StreamingChatLanguageModel> generationStreamingModel(LangChain4j

if (watsonRuntimeConfig.enableIntegration()) {

var builder = generationBuilder(watsonRuntimeConfig, configName);
var builder = generationBuilder(runtimeConfig, configName);
return new Supplier<>() {
@Override
public StreamingChatLanguageModel get() {
Expand All @@ -152,19 +150,20 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jWatsonxConfig runtimeC
LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName);

if (watsonConfig.enableIntegration()) {
var configProblems = checkConfigurations(watsonConfig, configName);
var configProblems = checkConfigurations(runtimeConfig, configName);

if (!configProblems.isEmpty()) {
throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS));
}

String iamUrl = watsonConfig.iam().baseUrl().toExternalForm();
WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl,
createTokenGenerator(watsonConfig.iam(), watsonConfig.apiKey()));
createTokenGenerator(watsonConfig.iam(),
firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey())));

URL url;
try {
url = new URL(watsonConfig.baseUrl());
url = new URL(firstOrDefault(null, watsonConfig.baseUrl(), runtimeConfig.defaultConfig().baseUrl()));
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -177,8 +176,8 @@ public Supplier<EmbeddingModel> embeddingModel(LangChain4jWatsonxConfig runtimeC
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), watsonConfig.logRequests()))
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), watsonConfig.logResponses()))
.version(watsonConfig.version())
.spaceId(watsonConfig.spaceId().orElse(null))
.projectId(watsonConfig.projectId().orElse(null))
.spaceId(firstOrDefault(null, watsonConfig.spaceId(), runtimeConfig.defaultConfig().spaceId()))
.projectId(firstOrDefault(null, watsonConfig.projectId(), runtimeConfig.defaultConfig().projectId()))
.modelId(embeddingModelConfig.modelId())
.truncateInputTokens(embeddingModelConfig.truncateInputTokens().orElse(null));

Expand All @@ -204,19 +203,20 @@ public EmbeddingModel get() {
public Supplier<ScoringModel> scoringModel(LangChain4jWatsonxConfig runtimeConfig, String configName) {
LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName);

var configProblems = checkConfigurations(watsonConfig, configName);
var configProblems = checkConfigurations(runtimeConfig, configName);

if (!configProblems.isEmpty()) {
throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS));
}

String iamUrl = watsonConfig.iam().baseUrl().toExternalForm();
WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl,
createTokenGenerator(watsonConfig.iam(), watsonConfig.apiKey()));
createTokenGenerator(watsonConfig.iam(),
firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey())));

URL url;
try {
url = new URL(watsonConfig.baseUrl());
url = new URL(firstOrDefault(null, watsonConfig.baseUrl(), runtimeConfig.defaultConfig().baseUrl()));
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -229,8 +229,8 @@ public Supplier<ScoringModel> scoringModel(LangChain4jWatsonxConfig runtimeConfi
.logRequests(firstOrDefault(false, rerankModelConfig.logRequests(), watsonConfig.logRequests()))
.logResponses(firstOrDefault(false, rerankModelConfig.logResponses(), watsonConfig.logResponses()))
.version(watsonConfig.version())
.spaceId(watsonConfig.spaceId().orElse(null))
.projectId(watsonConfig.projectId().orElse(null))
.spaceId(firstOrDefault(null, watsonConfig.spaceId(), runtimeConfig.defaultConfig().spaceId()))
.projectId(firstOrDefault(null, watsonConfig.projectId(), runtimeConfig.defaultConfig().projectId()))
.modelId(rerankModelConfig.modelId())
.truncateInputTokens(rerankModelConfig.truncateInputTokens().orElse(null));

Expand All @@ -253,38 +253,38 @@ public WatsonxTokenGenerator apply(String iamUrl) {
};
}

private WatsonxChatModel.Builder chatBuilder(
LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig,
String configName) {
private WatsonxChatModel.Builder chatBuilder(LangChain4jWatsonxConfig runtimeConfig, String configName) {
LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName);

ChatModelConfig chatModelConfig = watsonRuntimeConfig.chatModel();
var configProblems = checkConfigurations(watsonRuntimeConfig, configName);
var configProblems = checkConfigurations(runtimeConfig, configName);

if (!configProblems.isEmpty()) {
throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS));
}

String iamUrl = watsonRuntimeConfig.iam().baseUrl().toExternalForm();
ChatModelConfig chatModelConfig = watsonConfig.chatModel();
String iamUrl = watsonConfig.iam().baseUrl().toExternalForm();
WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl,
createTokenGenerator(watsonRuntimeConfig.iam(), watsonRuntimeConfig.apiKey()));
createTokenGenerator(watsonConfig.iam(),
firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey())));

URL url;
try {
url = new URL(watsonRuntimeConfig.baseUrl());
url = new URL(firstOrDefault(null, watsonConfig.baseUrl(), runtimeConfig.defaultConfig().baseUrl()));
} catch (Exception e) {
throw new RuntimeException(e);
}

return WatsonxChatModel.builder()
.tokenGenerator(tokenGenerator)
.url(url)
.timeout(watsonRuntimeConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), watsonRuntimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), watsonRuntimeConfig.logResponses()))
.version(watsonRuntimeConfig.version())
.spaceId(watsonRuntimeConfig.spaceId().orElse(null))
.projectId(watsonRuntimeConfig.projectId().orElse(null))
.modelId(watsonRuntimeConfig.generationModel().modelId())
.timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), watsonConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), watsonConfig.logResponses()))
.version(watsonConfig.version())
.spaceId(firstOrDefault(null, watsonConfig.spaceId(), runtimeConfig.defaultConfig().spaceId()))
.projectId(firstOrDefault(null, watsonConfig.projectId(), runtimeConfig.defaultConfig().projectId()))
.modelId(watsonConfig.chatModel().modelId())
.frequencyPenalty(chatModelConfig.frequencyPenalty())
.logprobs(chatModelConfig.logprobs())
.topLogprobs(chatModelConfig.topLogprobs().orElse(null))
Expand All @@ -296,24 +296,24 @@ private WatsonxChatModel.Builder chatBuilder(
.responseFormat(chatModelConfig.responseFormat().orElse(null));
}

private WatsonxGenerationModel.Builder generationBuilder(
LangChain4jWatsonxConfig.WatsonConfig watsonRuntimeConfig,
String configName) {
private WatsonxGenerationModel.Builder generationBuilder(LangChain4jWatsonxConfig runtimeConfig, String configName) {
LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName);

GenerationModelConfig generationModelConfig = watsonRuntimeConfig.generationModel();
var configProblems = checkConfigurations(watsonRuntimeConfig, configName);
var configProblems = checkConfigurations(runtimeConfig, configName);

if (!configProblems.isEmpty()) {
throw new ConfigValidationException(configProblems.toArray(EMPTY_PROBLEMS));
}

String iamUrl = watsonRuntimeConfig.iam().baseUrl().toExternalForm();
GenerationModelConfig generationModelConfig = watsonConfig.generationModel();
String iamUrl = watsonConfig.iam().baseUrl().toExternalForm();
WatsonxTokenGenerator tokenGenerator = tokenGeneratorCache.computeIfAbsent(iamUrl,
createTokenGenerator(watsonRuntimeConfig.iam(), watsonRuntimeConfig.apiKey()));
createTokenGenerator(watsonConfig.iam(),
firstOrDefault(null, watsonConfig.apiKey(), runtimeConfig.defaultConfig().apiKey())));

URL url;
try {
url = new URL(watsonRuntimeConfig.baseUrl());
url = new URL(firstOrDefault(null, watsonConfig.baseUrl(), runtimeConfig.defaultConfig().baseUrl()));
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -325,13 +325,13 @@ private WatsonxGenerationModel.Builder generationBuilder(
return WatsonxGenerationModel.builder()
.tokenGenerator(tokenGenerator)
.url(url)
.timeout(watsonRuntimeConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(firstOrDefault(false, generationModelConfig.logRequests(), watsonRuntimeConfig.logRequests()))
.logResponses(firstOrDefault(false, generationModelConfig.logResponses(), watsonRuntimeConfig.logResponses()))
.version(watsonRuntimeConfig.version())
.spaceId(watsonRuntimeConfig.spaceId().orElse(null))
.projectId(watsonRuntimeConfig.projectId().orElse(null))
.modelId(watsonRuntimeConfig.generationModel().modelId())
.timeout(watsonConfig.timeout().orElse(Duration.ofSeconds(10)))
.logRequests(firstOrDefault(false, generationModelConfig.logRequests(), watsonConfig.logRequests()))
.logResponses(firstOrDefault(false, generationModelConfig.logResponses(), watsonConfig.logResponses()))
.version(watsonConfig.version())
.spaceId(firstOrDefault(null, watsonConfig.spaceId(), runtimeConfig.defaultConfig().spaceId()))
.projectId(firstOrDefault(null, watsonConfig.projectId(), runtimeConfig.defaultConfig().projectId()))
.modelId(watsonConfig.generationModel().modelId())
.decodingMethod(generationModelConfig.decodingMethod())
.decayFactor(decayFactor)
.startIndex(startIndex)
Expand Down Expand Up @@ -359,20 +359,21 @@ private LangChain4jWatsonxConfig.WatsonConfig correspondingWatsonRuntimeConfig(L
return watsonConfig;
}

private List<ConfigValidationException.Problem> checkConfigurations(LangChain4jWatsonxConfig.WatsonConfig watsonConfig,
private List<ConfigValidationException.Problem> checkConfigurations(LangChain4jWatsonxConfig runtimeConfig,
String configName) {
List<ConfigValidationException.Problem> configProblems = new ArrayList<>();
LangChain4jWatsonxConfig.WatsonConfig watsonConfig = correspondingWatsonRuntimeConfig(runtimeConfig, configName);

if (DUMMY_URL.equals(watsonConfig.baseUrl())) {
if (watsonConfig.baseUrl().isEmpty() && runtimeConfig.defaultConfig().baseUrl().isEmpty()) {
configProblems.add(createBaseURLConfigProblem(configName));
}
String apiKey = watsonConfig.apiKey();
if (DUMMY_API_KEY.equals(apiKey)) {
if (watsonConfig.apiKey().isEmpty() && runtimeConfig.defaultConfig().apiKey().isEmpty()) {
configProblems.add(createApiKeyConfigProblem(configName));
}
if (watsonConfig.projectId().isEmpty() && watsonConfig.spaceId().isEmpty()) {
if (watsonConfig.projectId().isEmpty() && runtimeConfig.defaultConfig().projectId().isEmpty() &&
watsonConfig.spaceId().isEmpty() && runtimeConfig.defaultConfig().spaceId().isEmpty()) {
var config = NamedConfigUtil.isDefault(configName) ? "." : ("." + configName + ".");
var errorMessage = "One of the two properties quarkus.langchain4j.watsonx%s%s / quarkus.langchain4j.watsonx%s%s is required, but could not be found in any config source";
var errorMessage = "One of the properties quarkus.langchain4j.watsonx%s%s / quarkus.langchain4j.watsonx%s%s is required, but could not be found in any config source";
configProblems.add(new ConfigValidationException.Problem(
String.format(errorMessage, config, "project-id", config, "space-id")));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,14 @@ interface WatsonConfig {
/**
* Base URL of the watsonx.ai API.
*/
@WithDefault("https://dummy.ai/api")
String baseUrl();
Optional<String> baseUrl();

/**
* IBM Cloud API key.
* <p>
* To create a new API key, follow this <a href="https://cloud.ibm.com/iam/apikeys">link</a>.
*/
@WithDefault("dummy")
String apiKey();
Optional<String> apiKey();

/**
* Timeout for watsonx.ai calls.
Expand Down

0 comments on commit 4fe7711

Please sign in to comment.