Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keys to override config for chat, embedding and image-model for Azure OpenAI #1177

Merged
merged 1 commit into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ release.properties
# Quarkus CLI
.quarkus

# dotenv
.env

#Dolphin
.directory
/samples/chatbot/dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiImageModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig.AzureAiConfig.EndpointType;
import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient;
Expand Down Expand Up @@ -58,17 +57,16 @@ public Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
ChatModelConfig chatModelConfig = azureAiConfig.chatModel();
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);

var chatModelConfig = azureAiConfig.chatModel();
var apiKey = firstOrDefault(null, chatModelConfig.apiKey(), azureAiConfig.apiKey());
var adToken = firstOrDefault(null, chatModelConfig.adToken(), azureAiConfig.adToken());
var builder = AzureOpenAiChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiKey(apiKey)
.adToken(adToken)
// .tokenizer(new OpenAiTokenizer("<modelName>")) TODO: Set the tokenizer, it is always null!!
.apiVersion(azureAiConfig.apiVersion())
.apiVersion(chatModelConfig.apiVersion().orElse(azureAiConfig.apiVersion()))
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down Expand Up @@ -158,15 +156,15 @@ public Function<SyntheticCreationalContext<EmbeddingModel>, EmbeddingModel> embe
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel();
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);
var embeddingModelConfig = azureAiConfig.embeddingModel();
var apiKey = firstOrDefault(null, embeddingModelConfig.apiKey(), azureAiConfig.apiKey());
var adToken = firstOrDefault(null, embeddingModelConfig.adToken(), azureAiConfig.adToken());
var builder = AzureOpenAiEmbeddingModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.apiVersion(embeddingModelConfig.apiVersion().orElse(azureAiConfig.apiVersion()))
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down Expand Up @@ -195,15 +193,14 @@ public Function<SyntheticCreationalContext<ImageModel>, ImageModel> imageModel(L
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
var apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);

var imageModelConfig = azureAiConfig.imageModel();
var apiKey = firstOrDefault(null, imageModelConfig.apiKey(), azureAiConfig.apiKey());
var adToken = firstOrDefault(null, imageModelConfig.adToken(), azureAiConfig.adToken());
var builder = AzureOpenAiImageModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.IMAGE))
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
.apiVersion(imageModelConfig.apiVersion().orElse(azureAiConfig.apiVersion()))
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
.logRequests(firstOrDefault(false, imageModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ public interface ChatModelConfig {
@WithDefault(ConfigConstants.DUMMY_VALUE)
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Optional<String> apiVersion();

/**
* Azure OpenAI API key
*/
Optional<String> apiKey();

/**
* What sampling temperature to use, with values between 0 and 2.
* Higher values means the model will take more risks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ public interface EmbeddingModelConfig {
*/
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Optional<String> apiVersion();

/**
* Azure OpenAI API key
*/
Optional<String> apiKey();

/**
* Whether embedding model requests should be logged
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ public interface ImageModelConfig {
*/
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Optional<String> apiVersion();

/**
* Azure OpenAI API key
*/
Optional<String> apiKey();

/**
* Model name to use
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,21 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public Optional<String> apiVersion() {
return Optional.empty();
}

@Override
public Optional<String> apiKey() {
return Optional.empty();
}

@Override
public Double temperature() {
return null;
Expand Down Expand Up @@ -285,6 +300,21 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public Optional<String> apiVersion() {
return Optional.empty();
}

@Override
public Optional<String> apiKey() {
return Optional.empty();
}

@Override
public Optional<Boolean> logRequests() {
return Optional.empty();
Expand All @@ -294,6 +324,7 @@ public Optional<Boolean> logRequests() {
public Optional<Boolean> logResponses() {
return Optional.empty();
}

};
}

Expand All @@ -320,6 +351,21 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public Optional<String> apiVersion() {
return Optional.empty();
}

@Override
public Optional<String> apiKey() {
return Optional.empty();
}

@Override
public String modelName() {
return null;
Expand Down
Loading