Skip to content

Commit

Permalink
Merge pull request #1177 from flyinfish/bugfix/1154
Browse files Browse the repository at this point in the history
Add keys to override config for chat, embedding and image-model for Azure OpenAI
  • Loading branch information
geoand authored Dec 31, 2024
2 parents 547241c + 89b947d commit c7cff12
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ release.properties
# Quarkus CLI
.quarkus

# dotenv
.env

#Dolphin
.directory
/samples/chatbot/dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiImageModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.LangChain4jAzureOpenAiConfig.AzureAiConfig.EndpointType;
import io.quarkiverse.langchain4j.openai.common.QuarkusOpenAiClient;
Expand Down Expand Up @@ -58,17 +57,16 @@ public Function<SyntheticCreationalContext<ChatLanguageModel>, ChatLanguageModel
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
ChatModelConfig chatModelConfig = azureAiConfig.chatModel();
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);

var chatModelConfig = azureAiConfig.chatModel();
var apiKey = firstOrDefault(null, chatModelConfig.apiKey(), azureAiConfig.apiKey());
var adToken = firstOrDefault(null, chatModelConfig.adToken(), azureAiConfig.adToken());
var builder = AzureOpenAiChatModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.CHAT))
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiKey(apiKey)
.adToken(adToken)
// .tokenizer(new OpenAiTokenizer("<modelName>")) TODO: Set the tokenizer, it is always null!!
.apiVersion(azureAiConfig.apiVersion())
.apiVersion(chatModelConfig.apiVersion().orElse(azureAiConfig.apiVersion()))
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down Expand Up @@ -158,15 +156,15 @@ public Function<SyntheticCreationalContext<EmbeddingModel>, EmbeddingModel> embe
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
EmbeddingModelConfig embeddingModelConfig = azureAiConfig.embeddingModel();
String apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);
var embeddingModelConfig = azureAiConfig.embeddingModel();
var apiKey = firstOrDefault(null, embeddingModelConfig.apiKey(), azureAiConfig.apiKey());
var adToken = firstOrDefault(null, embeddingModelConfig.adToken(), azureAiConfig.adToken());
var builder = AzureOpenAiEmbeddingModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.EMBEDDING))
.apiKey(apiKey)
.adToken(adToken)
.configName(NamedConfigUtil.isDefault(configName) ? null : configName)
.apiVersion(azureAiConfig.apiVersion())
.apiVersion(embeddingModelConfig.apiVersion().orElse(azureAiConfig.apiVersion()))
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down Expand Up @@ -195,15 +193,14 @@ public Function<SyntheticCreationalContext<ImageModel>, ImageModel> imageModel(L
LangChain4jAzureOpenAiConfig.AzureAiConfig azureAiConfig = correspondingAzureOpenAiConfig(runtimeConfig, configName);

if (azureAiConfig.enableIntegration()) {
var apiKey = azureAiConfig.apiKey().orElse(null);
String adToken = azureAiConfig.adToken().orElse(null);

var imageModelConfig = azureAiConfig.imageModel();
var apiKey = firstOrDefault(null, imageModelConfig.apiKey(), azureAiConfig.apiKey());
var adToken = firstOrDefault(null, imageModelConfig.adToken(), azureAiConfig.adToken());
var builder = AzureOpenAiImageModel.builder()
.endpoint(getEndpoint(azureAiConfig, configName, EndpointType.IMAGE))
.apiKey(apiKey)
.adToken(adToken)
.apiVersion(azureAiConfig.apiVersion())
.apiVersion(imageModelConfig.apiVersion().orElse(azureAiConfig.apiVersion()))
.timeout(azureAiConfig.timeout().orElse(Duration.ofSeconds(10)))
.maxRetries(azureAiConfig.maxRetries())
.logRequests(firstOrDefault(false, imageModelConfig.logRequests(), azureAiConfig.logRequests()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ public interface ChatModelConfig {
@WithDefault(ConfigConstants.DUMMY_VALUE)
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Optional<String> apiVersion();

/**
* Azure OpenAI API key
*/
Optional<String> apiKey();

/**
* What sampling temperature to use, with values between 0 and 2.
* Higher values means the model will take more risks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ public interface EmbeddingModelConfig {
*/
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Optional<String> apiVersion();

/**
* Azure OpenAI API key
*/
Optional<String> apiKey();

/**
* Whether embedding model requests should be logged
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ public interface ImageModelConfig {
*/
Optional<String> endpoint();

/**
* The Azure AD token to use for this operation.
* If present, then the requests towards OpenAI will include this in the Authorization header.
* Note that this property overrides the functionality of {@code quarkus.langchain4j.azure-openai.embedding-model.api-key}.
*/
Optional<String> adToken();

/**
* The API version to use for this operation. This follows the YYYY-MM-DD format
*/
Optional<String> apiVersion();

/**
* Azure OpenAI API key
*/
Optional<String> apiKey();

/**
* Model name to use
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,21 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public Optional<String> apiVersion() {
return Optional.empty();
}

@Override
public Optional<String> apiKey() {
return Optional.empty();
}

@Override
public Double temperature() {
return null;
Expand Down Expand Up @@ -285,6 +300,21 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public Optional<String> apiVersion() {
return Optional.empty();
}

@Override
public Optional<String> apiKey() {
return Optional.empty();
}

@Override
public Optional<Boolean> logRequests() {
return Optional.empty();
Expand All @@ -294,6 +324,7 @@ public Optional<Boolean> logRequests() {
public Optional<Boolean> logResponses() {
return Optional.empty();
}

};
}

Expand All @@ -320,6 +351,21 @@ public Optional<String> endpoint() {
return Optional.empty();
}

@Override
public Optional<String> adToken() {
return Optional.empty();
}

@Override
public Optional<String> apiVersion() {
return Optional.empty();
}

@Override
public Optional<String> apiKey() {
return Optional.empty();
}

@Override
public String modelName() {
return null;
Expand Down

0 comments on commit c7cff12

Please sign in to comment.