Skip to content

Commit

Permalink
Merge pull request #200 from quarkiverse/#183
Browse files Browse the repository at this point in the history
Allows for specific model config options to override client-wide logging config
  • Loading branch information
geoand authored Jan 3, 2024
2 parents a67a589 + 8370674 commit fd2dd81
Show file tree
Hide file tree
Showing 14 changed files with 168 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.runtime;

import java.util.Optional;

public class OptionalUtil {

@SafeVarargs
public static <T> T firstOrDefault(T defaultValue, Optional<T>... values) {
for (Optional<T> o : values) {
if (o != null && o.isPresent()) {
return o.get();
}
}
return defaultValue;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.langchain4j.huggingface.runtime;

import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault;

import java.net.URL;
import java.util.Optional;
import java.util.function.Supplier;
Expand Down Expand Up @@ -32,8 +34,8 @@ public Supplier<?> chatModel(Langchain4jHuggingFaceConfig runtimeConfig) {
.topP(chatModelConfig.topP())
.topK(chatModelConfig.topK())
.repetitionPenalty(chatModelConfig.repetitionPenalty())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses());
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()));

if (apiKeyOpt.isPresent()) {
builder.accessToken(apiKeyOpt.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.OptionalDouble;
import java.util.OptionalInt;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;
import io.smallrye.config.WithDefault;

Expand Down Expand Up @@ -75,4 +76,16 @@ public interface ChatModelConfig {
*/
OptionalDouble repetitionPenalty();

/**
* Whether chat model requests should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether chat model responses should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logResponses();

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.time.Duration;
import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;
import io.smallrye.config.WithDefault;
Expand Down Expand Up @@ -35,14 +36,14 @@ public interface Langchain4jHuggingFaceConfig {
EmbeddingModelConfig embeddingModel();

/**
* Whether the OpenAI client should log requests
* Whether the HuggingFace client should log requests
*/
@WithDefault("false")
Boolean logRequests();
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether the OpenAI client should log responses
* Whether the HuggingFace client should log responses
*/
@WithDefault("false")
Boolean logResponses();
@ConfigDocDefault("false")
Optional<Boolean> logResponses();
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package io.quarkiverse.langchain4j.azure.openai.runtime;

import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault;

import java.util.function.Supplier;

import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiEmbeddingModel;
import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig;
import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig;
import io.quarkiverse.langchain4j.openai.QuarkusOpenAiClient;
import io.quarkus.runtime.ShutdownContext;
Expand All @@ -22,8 +25,8 @@ public Supplier<?> chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
.apiVersion(runtimeConfig.apiVersion())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()))

.temperature(chatModelConfig.temperature())
.topP(chatModelConfig.topP())
Expand All @@ -48,8 +51,8 @@ public Supplier<?> streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig
.baseUrl(getBaseUrl(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()))

.temperature(chatModelConfig.temperature())
.topP(chatModelConfig.topP())
Expand All @@ -69,13 +72,14 @@ public Object get() {
}

public Supplier<?> embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) {
EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel();
var builder = AzureOpenAiEmbeddingModel.builder()
.baseUrl(getBaseUrl(runtimeConfig))
.apiKey(runtimeConfig.apiKey())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses());
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses()));

return new Supplier<>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;
import io.smallrye.config.WithDefault;

Expand Down Expand Up @@ -49,4 +50,16 @@ public interface ChatModelConfig {
*/
@WithDefault("0")
Double frequencyPenalty();

/**
* Whether chat model requests should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether chat model responses should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logResponses();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.quarkiverse.langchain4j.azure.openai.runtime.config;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;

@ConfigGroup
public interface EmbeddingModelConfig {

    /**
     * Whether requests made by the embedding model should be logged.
     * When unset, falls back to the client-wide logging setting.
     */
    @ConfigDocDefault("false")
    Optional<Boolean> logRequests();

    /**
     * Whether responses received by the embedding model should be logged.
     * When unset, falls back to the client-wide logging setting.
     */
    @ConfigDocDefault("false")
    Optional<Boolean> logResponses();
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME;

import java.time.Duration;
import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;
import io.smallrye.config.WithDefault;
Expand Down Expand Up @@ -55,17 +57,22 @@ public interface Langchain4jAzureOpenAiConfig {
/**
* Whether the OpenAI client should log requests
*/
@WithDefault("false")
Boolean logRequests();
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether the OpenAI client should log responses
*/
@WithDefault("false")
Boolean logResponses();
@ConfigDocDefault("false")
Optional<Boolean> logResponses();

/**
* Chat model related settings
*/
ChatModelConfig chatModel();

/**
* Embedding model related settings
*/
EmbeddingModelConfig embeddingModel();
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.langchain4j.openai.runtime;

import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;
Expand Down Expand Up @@ -34,8 +36,8 @@ public Supplier<?> chatModel(Langchain4jOpenAiConfig runtimeConfig) {
.apiKey(apiKeyOpt.get())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(chatModelConfig.modelName())
.temperature(chatModelConfig.temperature())
Expand Down Expand Up @@ -65,8 +67,8 @@ public Supplier<?> streamingChatModel(Langchain4jOpenAiConfig runtimeConfig) {
.baseUrl(runtimeConfig.baseUrl())
.apiKey(apiKeyOpt.get())
.timeout(runtimeConfig.timeout())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(chatModelConfig.modelName())
.temperature(chatModelConfig.temperature())
Expand Down Expand Up @@ -97,8 +99,8 @@ public Supplier<?> embeddingModel(Langchain4jOpenAiConfig runtimeConfig) {
.apiKey(apiKeyOpt.get())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(embeddingModelConfig.modelName());

Expand All @@ -121,8 +123,8 @@ public Supplier<?> moderationModel(Langchain4jOpenAiConfig runtimeConfig) {
.apiKey(apiKeyOpt.get())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, moderationModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(moderationModelConfig.modelName());

Expand All @@ -145,8 +147,8 @@ public Supplier<?> imageModel(Langchain4jOpenAiConfig runtimeConfig) {
.apiKey(apiKeyOpt.get())
.timeout(runtimeConfig.timeout())
.maxRetries(runtimeConfig.maxRetries())
.logRequests(runtimeConfig.logRequests())
.logResponses(runtimeConfig.logResponses())
.logRequests(firstOrDefault(false, imageModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, imageModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(imageModelConfig.modelName())
.size(imageModelConfig.size())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;
import io.smallrye.config.WithDefault;

Expand Down Expand Up @@ -55,4 +56,16 @@ public interface ChatModelConfig {
*/
@WithDefault("0")
Double frequencyPenalty();

/**
* Whether chat model requests should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether chat model responses should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logResponses();
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package io.quarkiverse.langchain4j.openai.runtime.config;

import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigGroup;
import io.smallrye.config.WithDefault;

Expand All @@ -11,4 +14,16 @@ public interface EmbeddingModelConfig {
*/
@WithDefault("text-embedding-ada-002")
String modelName();

/**
* Whether embedding model requests should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether embedding model responses should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logResponses();
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,16 @@ public interface ImageModelConfig {
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
*/
Optional<String> user();

/**
* Whether image model requests should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether image model responses should be logged
*/
@ConfigDocDefault("false")
Optional<Boolean> logResponses();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.time.Duration;
import java.util.Optional;

import io.quarkus.runtime.annotations.ConfigDocDefault;
import io.quarkus.runtime.annotations.ConfigRoot;
import io.smallrye.config.ConfigMapping;
import io.smallrye.config.WithDefault;
Expand Down Expand Up @@ -39,14 +40,14 @@ public interface Langchain4jOpenAiConfig {
/**
* Whether the OpenAI client should log requests
*/
@WithDefault("false")
Boolean logRequests();
@ConfigDocDefault("false")
Optional<Boolean> logRequests();

/**
* Whether the OpenAI client should log responses
*/
@WithDefault("false")
Boolean logResponses();
@ConfigDocDefault("false")
Optional<Boolean> logResponses();

/**
* Chat model related settings
Expand Down
Loading

0 comments on commit fd2dd81

Please sign in to comment.