From 009104f95a0cfdfe61cd94901b6bf8057a956fbe Mon Sep 17 00:00:00 2001 From: Ioannis Canellos Date: Tue, 17 Dec 2024 09:46:07 +0200 Subject: [PATCH] fix: don't require api key when using custom base url for mistralai --- .../mistralai/runtime/MistralAiRecorder.java | 29 +++++++++++-------- .../config/LangChain4jMistralAiConfig.java | 8 +++-- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java b/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java index bd4a07020..0f0a013fb 100644 --- a/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java +++ b/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/MistralAiRecorder.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.mistralai.runtime; +import static io.quarkiverse.langchain4j.mistralai.runtime.config.LangChain4jMistralAiConfig.MistralAiConfig.DEFAULT_API_KEY; +import static io.quarkiverse.langchain4j.mistralai.runtime.config.LangChain4jMistralAiConfig.MistralAiConfig.DEFAULT_BASE_URL; import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; import java.time.Duration; @@ -27,22 +29,22 @@ @Recorder public class MistralAiRecorder { - private static final String DUMMY_KEY = "dummy"; public Supplier chatModel(LangChain4jMistralAiConfig runtimeConfig, String configName) { LangChain4jMistralAiConfig.MistralAiConfig mistralAiConfig = correspondingMistralAiConfig(runtimeConfig, configName); if (mistralAiConfig.enableIntegration()) { - String apiKey = mistralAiConfig.apiKey(); ChatModelConfig chatModelConfig = mistralAiConfig.chatModel(); - if (DUMMY_KEY.equals(apiKey)) { + String apiKey = mistralAiConfig.apiKey(); + String baseUrl = mistralAiConfig.baseUrl(); + if (DEFAULT_API_KEY.equals(apiKey) && DEFAULT_BASE_URL.equals(baseUrl)) { throw new ConfigValidationException(createApiKeyConfigProblem(configName)); } var builder = MistralAiChatModel.builder() - .baseUrl(mistralAiConfig.baseUrl()) + .baseUrl(baseUrl) .apiKey(apiKey) .modelName(chatModelConfig.modelName()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), mistralAiConfig.logRequests())) @@ -87,15 +89,16 @@ public Supplier streamingChatModel(LangChain4jMistra configName); if (mistralAiConfig.enableIntegration()) { - String apiKey = mistralAiConfig.apiKey(); ChatModelConfig chatModelConfig = mistralAiConfig.chatModel(); - if (DUMMY_KEY.equals(apiKey)) { + String apiKey = mistralAiConfig.apiKey(); + String baseUrl = mistralAiConfig.baseUrl(); + if (DEFAULT_API_KEY.equals(apiKey) && DEFAULT_BASE_URL.equals(baseUrl)) { throw new ConfigValidationException(createApiKeyConfigProblem(configName)); } var builder = MistralAiStreamingChatModel.builder() - .baseUrl(mistralAiConfig.baseUrl()) + .baseUrl(baseUrl) .apiKey(apiKey) .modelName(chatModelConfig.modelName()) .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), mistralAiConfig.logRequests())) @@ -139,10 +142,11 @@ public Supplier embeddingModel(LangChain4jMistralAiConfig runtim configName); if (mistralAiConfig.enableIntegration()) { - String apiKey = mistralAiConfig.apiKey(); EmbeddingModelConfig embeddingModelConfig = mistralAiConfig.embeddingModel(); - if (DUMMY_KEY.equals(apiKey)) { + String apiKey = mistralAiConfig.apiKey(); + String baseUrl = mistralAiConfig.baseUrl(); + if (DEFAULT_API_KEY.equals(apiKey) && DEFAULT_BASE_URL.equals(baseUrl)) { throw new ConfigValidationException(createApiKeyConfigProblem(configName)); } @@ -175,15 +179,16 @@ public Supplier moderationModel(LangChain4jMistralAiConfig runt configName); if (mistralAiConfig.enableIntegration()) { - String apiKey = mistralAiConfig.apiKey(); ModerationModelConfig moderationModelConfig = mistralAiConfig.moderationModel(); - if (DUMMY_KEY.equals(apiKey)) { + String apiKey = mistralAiConfig.apiKey(); + String baseUrl = mistralAiConfig.baseUrl(); + if (DEFAULT_API_KEY.equals(apiKey) && DEFAULT_BASE_URL.equals(baseUrl)) { throw new ConfigValidationException(createApiKeyConfigProblem(configName)); } var builder = new MistralAiModerationModel.Builder() - .baseUrl(mistralAiConfig.baseUrl()) + .baseUrl(baseUrl) .apiKey(apiKey) .modelName(moderationModelConfig.modelName()) .logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), mistralAiConfig.logRequests())) diff --git a/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java b/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java index 4e610a099..22a8b112b 100644 --- a/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java +++ b/model-providers/mistral/runtime/src/main/java/io/quarkiverse/langchain4j/mistralai/runtime/config/LangChain4jMistralAiConfig.java @@ -37,16 +37,20 @@ public interface LangChain4jMistralAiConfig { @ConfigGroup interface MistralAiConfig { + + String DEFAULT_BASE_URL = "https://api.mistral.ai/v1/"; + String DEFAULT_API_KEY = "dummy"; + /** * Base URL of Mistral API */ - @WithDefault("https://api.mistral.ai/v1/") + @WithDefault(DEFAULT_BASE_URL) String baseUrl(); /** * Mistral API key */ - @WithDefault("dummy") // TODO: this should be optional but Smallrye Config doesn't like it + @WithDefault(DEFAULT_API_KEY) String apiKey(); /**