From 837067459876f9ed4fb41eefc1f007a408f89b91 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Wed, 3 Jan 2024 14:49:13 +0200 Subject: [PATCH] Allows for specific model config options to override client-wide logging config Closes: #183 --- .../langchain4j/runtime/OptionalUtil.java | 16 ++++++++++++++ .../runtime/HuggingFaceRecorder.java | 6 +++-- .../runtime/config/ChatModelConfig.java | 13 +++++++++++ .../config/Langchain4jHuggingFaceConfig.java | 13 ++++++----- .../openai/runtime/AzureOpenAiRecorder.java | 16 +++++++++----- .../runtime/config/ChatModelConfig.java | 13 +++++++++++ .../runtime/config/EmbeddingModelConfig.java | 22 +++++++++++++++++++ .../config/Langchain4jAzureOpenAiConfig.java | 15 +++++++++---- .../openai/runtime/OpenAiRecorder.java | 22 ++++++++++--------- .../runtime/config/ChatModelConfig.java | 13 +++++++++++ .../runtime/config/EmbeddingModelConfig.java | 15 +++++++++++++ .../runtime/config/ImageModelConfig.java | 12 ++++++++++ .../config/Langchain4jOpenAiConfig.java | 9 ++++---- .../runtime/config/ModerationModelConfig.java | 15 +++++++++++++ 14 files changed, 168 insertions(+), 32 deletions(-) create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/OptionalUtil.java create mode 100644 openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/OptionalUtil.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/OptionalUtil.java new file mode 100644 index 000000000..94be8cc3b --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/OptionalUtil.java @@ -0,0 +1,16 @@ +package io.quarkiverse.langchain4j.runtime; + +import java.util.Optional; + +public class OptionalUtil { + + @SafeVarargs + public static <T> T firstOrDefault(T defaultValue, Optional<T>... 
values) { + for (Optional<T> o : values) { + if (o != null && o.isPresent()) { + return o.get(); + } + } + return defaultValue; + } +} diff --git a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java index fe24fce6b..644a1e713 100644 --- a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java +++ b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/HuggingFaceRecorder.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.huggingface.runtime; +import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; + import java.net.URL; import java.util.Optional; import java.util.function.Supplier; @@ -32,8 +34,8 @@ public Supplier chatModel(Langchain4jHuggingFaceConfig runtimeConfig) { .topP(chatModelConfig.topP()) .topK(chatModelConfig.topK()) .repetitionPenalty(chatModelConfig.repetitionPenalty()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()); + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())); if (apiKeyOpt.isPresent()) { builder.accessToken(apiKeyOpt.get()); diff --git a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/ChatModelConfig.java b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/ChatModelConfig.java index ff797e8f0..0b74cc41d 100644 --- a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/ChatModelConfig.java +++ b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/ChatModelConfig.java @@ -5,6 +5,7 @@ import java.util.OptionalDouble; import 
java.util.OptionalInt; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigGroup; import io.smallrye.config.WithDefault; @@ -75,4 +76,16 @@ public interface ChatModelConfig { */ OptionalDouble repetitionPenalty(); + /** + * Whether chat model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + /** + * Whether chat model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); + } diff --git a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java index d830b8538..7813b7d54 100644 --- a/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java +++ b/hugging-face/runtime/src/main/java/io/quarkiverse/langchain4j/huggingface/runtime/config/Langchain4jHuggingFaceConfig.java @@ -5,6 +5,7 @@ import java.time.Duration; import java.util.Optional; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; @@ -35,14 +36,14 @@ public interface Langchain4jHuggingFaceConfig { EmbeddingModelConfig embeddingModel(); /** - * Whether the OpenAI client should log requests + * Whether the HuggingFace client should log requests */ - @WithDefault("false") - Boolean logRequests(); + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); /** - * Whether the OpenAI client should log responses + * Whether the HuggingFace client should log responses */ - @WithDefault("false") - Boolean logResponses(); + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); } diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java 
b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java index 8e70273e4..1f4181642 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/AzureOpenAiRecorder.java @@ -1,11 +1,14 @@ package io.quarkiverse.langchain4j.azure.openai.runtime; +import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; + import java.util.function.Supplier; import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiChatModel; import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiEmbeddingModel; import io.quarkiverse.langchain4j.azure.openai.AzureOpenAiStreamingChatModel; import io.quarkiverse.langchain4j.azure.openai.runtime.config.ChatModelConfig; +import io.quarkiverse.langchain4j.azure.openai.runtime.config.EmbeddingModelConfig; import io.quarkiverse.langchain4j.azure.openai.runtime.config.Langchain4jAzureOpenAiConfig; import io.quarkiverse.langchain4j.openai.QuarkusOpenAiClient; import io.quarkus.runtime.ShutdownContext; @@ -22,8 +25,8 @@ public Supplier chatModel(Langchain4jAzureOpenAiConfig runtimeConfig) { .apiVersion(runtimeConfig.apiVersion()) .timeout(runtimeConfig.timeout()) .maxRetries(runtimeConfig.maxRetries()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) @@ -48,8 +51,8 @@ public Supplier streamingChatModel(Langchain4jAzureOpenAiConfig runtimeConfig .baseUrl(getBaseUrl(runtimeConfig)) .apiKey(runtimeConfig.apiKey()) .timeout(runtimeConfig.timeout()) - .logRequests(runtimeConfig.logRequests()) - 
.logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) .temperature(chatModelConfig.temperature()) .topP(chatModelConfig.topP()) @@ -69,13 +72,14 @@ public Object get() { } public Supplier embeddingModel(Langchain4jAzureOpenAiConfig runtimeConfig) { + EmbeddingModelConfig embeddingModelConfig = runtimeConfig.embeddingModel(); var builder = AzureOpenAiEmbeddingModel.builder() .baseUrl(getBaseUrl(runtimeConfig)) .apiKey(runtimeConfig.apiKey()) .timeout(runtimeConfig.timeout()) .maxRetries(runtimeConfig.maxRetries()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()); + .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses())); return new Supplier<>() { @Override diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java index a512a9f40..e59ef8ab6 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/ChatModelConfig.java @@ -2,6 +2,7 @@ import java.util.Optional; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigGroup; import io.smallrye.config.WithDefault; @@ -49,4 +50,16 @@ public interface ChatModelConfig { */ @WithDefault("0") Double frequencyPenalty(); + + /** + * Whether chat model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + + /** + * 
Whether chat model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); } diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java new file mode 100644 index 000000000..f25c0367b --- /dev/null +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/EmbeddingModelConfig.java @@ -0,0 +1,22 @@ +package io.quarkiverse.langchain4j.azure.openai.runtime.config; + +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; +import io.quarkus.runtime.annotations.ConfigGroup; + +@ConfigGroup +public interface EmbeddingModelConfig { + + /** + * Whether embedding model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + + /** + * Whether embedding model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); +} diff --git a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java index 569847c45..0ebc3b692 100644 --- a/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java +++ b/openai/azure-openai/runtime/src/main/java/io/quarkiverse/langchain4j/azure/openai/runtime/config/Langchain4jAzureOpenAiConfig.java @@ -3,7 +3,9 @@ import static io.quarkus.runtime.annotations.ConfigPhase.RUN_TIME; import java.time.Duration; +import java.util.Optional; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; 
@@ -55,17 +57,22 @@ public interface Langchain4jAzureOpenAiConfig { /** * Whether the OpenAI client should log requests */ - @WithDefault("false") - Boolean logRequests(); + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); /** * Whether the OpenAI client should log responses */ - @WithDefault("false") - Boolean logResponses(); + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); /** * Chat model related settings */ ChatModelConfig chatModel(); + + /** + * Embedding model related settings + */ + EmbeddingModelConfig embeddingModel(); } diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java index d9a6534c7..346372cd9 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.openai.runtime; +import static io.quarkiverse.langchain4j.runtime.OptionalUtil.firstOrDefault; + import java.nio.file.Path; import java.nio.file.Paths; import java.util.Optional; @@ -34,8 +36,8 @@ public Supplier chatModel(Langchain4jOpenAiConfig runtimeConfig) { .apiKey(apiKeyOpt.get()) .timeout(runtimeConfig.timeout()) .maxRetries(runtimeConfig.maxRetries()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) .modelName(chatModelConfig.modelName()) .temperature(chatModelConfig.temperature()) @@ -65,8 +67,8 @@ public Supplier streamingChatModel(Langchain4jOpenAiConfig runtimeConfig) { .baseUrl(runtimeConfig.baseUrl()) .apiKey(apiKeyOpt.get()) 
.timeout(runtimeConfig.timeout()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses())) .modelName(chatModelConfig.modelName()) .temperature(chatModelConfig.temperature()) @@ -97,8 +99,8 @@ public Supplier embeddingModel(Langchain4jOpenAiConfig runtimeConfig) { .apiKey(apiKeyOpt.get()) .timeout(runtimeConfig.timeout()) .maxRetries(runtimeConfig.maxRetries()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses())) .modelName(embeddingModelConfig.modelName()); @@ -121,8 +123,8 @@ public Supplier moderationModel(Langchain4jOpenAiConfig runtimeConfig) { .apiKey(apiKeyOpt.get()) .timeout(runtimeConfig.timeout()) .maxRetries(runtimeConfig.maxRetries()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, moderationModelConfig.logResponses(), runtimeConfig.logResponses())) .modelName(moderationModelConfig.modelName()); @@ -145,8 +147,8 @@ public Supplier imageModel(Langchain4jOpenAiConfig runtimeConfig) { .apiKey(apiKeyOpt.get()) .timeout(runtimeConfig.timeout()) .maxRetries(runtimeConfig.maxRetries()) - .logRequests(runtimeConfig.logRequests()) - .logResponses(runtimeConfig.logResponses()) + .logRequests(firstOrDefault(false, imageModelConfig.logRequests(), runtimeConfig.logRequests())) + .logResponses(firstOrDefault(false, imageModelConfig.logResponses(), runtimeConfig.logResponses())) .modelName(imageModelConfig.modelName()) 
.size(imageModelConfig.size()) diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java index 54c9283dc..87182dfbb 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java @@ -2,6 +2,7 @@ import java.util.Optional; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigGroup; import io.smallrye.config.WithDefault; @@ -55,4 +56,16 @@ public interface ChatModelConfig { */ @WithDefault("0") Double frequencyPenalty(); + + /** + * Whether chat model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + + /** + * Whether chat model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); } diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/EmbeddingModelConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/EmbeddingModelConfig.java index c7762f48f..ec2b06ad9 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/EmbeddingModelConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/EmbeddingModelConfig.java @@ -1,5 +1,8 @@ package io.quarkiverse.langchain4j.openai.runtime.config; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigGroup; import io.smallrye.config.WithDefault; @@ -11,4 +14,16 @@ public interface EmbeddingModelConfig { */ @WithDefault("text-embedding-ada-002") String modelName(); + + /** + * 
Whether embedding model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + + /** + * Whether embedding model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); } diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ImageModelConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ImageModelConfig.java index 0720f3ea2..fe0a40302 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ImageModelConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ImageModelConfig.java @@ -84,4 +84,16 @@ public interface ImageModelConfig { * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. */ Optional<String> user(); + + /** + * Whether image model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + + /** + * Whether image model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); } diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java index ef694a6f2..dfee8791c 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/Langchain4jOpenAiConfig.java @@ -5,6 +5,7 @@ import java.time.Duration; import java.util.Optional; +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; @@ -39,14 +40,14 @@ 
public interface Langchain4jOpenAiConfig { /** * Whether the OpenAI client should log requests */ - @WithDefault("false") - Boolean logRequests(); + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); /** * Whether the OpenAI client should log responses */ - @WithDefault("false") - Boolean logResponses(); + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); /** * Chat model related settings diff --git a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ModerationModelConfig.java b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ModerationModelConfig.java index 27b2b99f6..ad3579fdf 100644 --- a/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ModerationModelConfig.java +++ b/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ModerationModelConfig.java @@ -1,5 +1,8 @@ package io.quarkiverse.langchain4j.openai.runtime.config; +import java.util.Optional; + +import io.quarkus.runtime.annotations.ConfigDocDefault; import io.quarkus.runtime.annotations.ConfigGroup; import io.smallrye.config.WithDefault; @@ -11,4 +14,16 @@ public interface ModerationModelConfig { */ @WithDefault("text-moderation-latest") String modelName(); + + /** + * Whether moderation model requests should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logRequests(); + + /** + * Whether moderation model responses should be logged + */ + @ConfigDocDefault("false") + Optional<Boolean> logResponses(); }