Skip to content

Commit

Permalink
Allow specifying the organizationId in the configuration
Browse files Browse the repository at this point in the history
Fixes #131

# Conflicts:
#	openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java
  • Loading branch information
edeandrea committed Jan 3, 2024
1 parent cd127b7 commit 4e751eb
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 8 deletions.
17 changes: 17 additions & 0 deletions docs/modules/ROOT/pages/includes/quarkus-langchain4j-openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,23 @@ endif::add-copy-button-to-env-var[]
|


a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.organization-id]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.organization-id[quarkus.langchain4j.openai.organization-id]`


[.description]
--
OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional)

ifdef::add-copy-button-to-env-var[]
Environment variable: env_var_with_copy_button:+++QUARKUS_LANGCHAIN4J_OPENAI_ORGANIZATION_ID+++[]
endif::add-copy-button-to-env-var[]
ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_LANGCHAIN4J_OPENAI_ORGANIZATION_ID+++`
endif::add-copy-button-to-env-var[]
--|string
|


a| [[quarkus-langchain4j-openai_quarkus.langchain4j.openai.timeout]]`link:#quarkus-langchain4j-openai_quarkus.langchain4j.openai.timeout[quarkus.langchain4j.openai.timeout]`


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.jboss.resteasy.reactive.client.SseEvent;
import org.jboss.resteasy.reactive.client.SseEventFilter;
import org.jboss.resteasy.reactive.client.api.ClientLogger;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestContext;
import org.jboss.resteasy.reactive.client.spi.ResteasyReactiveClientRequestFilter;
import org.jboss.resteasy.reactive.common.providers.serialisers.AbstractJsonMessageBodyReader;

import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -298,6 +300,19 @@ public void aroundWriteTo(WriterInterceptorContext context) throws IOException,
}
}

class OpenAiOrganizationIdRequestFilter implements ResteasyReactiveClientRequestFilter {
private final String organizationId;

public OpenAiOrganizationIdRequestFilter(String organizationId) {
this.organizationId = organizationId;
}

@Override
public void filter(ResteasyReactiveClientRequestContext requestContext) {
requestContext.getHeaders().add("OpenAI-Organization", organizationId);
}
}

/**
* Introduce a custom logger as the stock one logs at the DEBUG level by default...
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ public OpenAiRestApi apply(Builder builder, OpenAiRestApi openAiRestApi) {
InetSocketAddress socketAddress = (InetSocketAddress) builder.proxy.address();
restApiBuilder.proxyAddress(socketAddress.getHostName(), socketAddress.getPort());
}

if (builder.organizationId != null) {
restApiBuilder.register(new OpenAiRestApi.OpenAiOrganizationIdRequestFilter(builder.organizationId));
}

return restApiBuilder.build(OpenAiRestApi.class);
} catch (URISyntaxException e) {
throw new RuntimeException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import java.util.Optional;
import java.util.function.Supplier;

import io.quarkus.runtime.ShutdownContext;
import io.quarkus.runtime.annotations.Recorder;

import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiModerationModel;
Expand All @@ -18,8 +21,6 @@
import io.quarkiverse.langchain4j.openai.runtime.config.ImageModelConfig;
import io.quarkiverse.langchain4j.openai.runtime.config.Langchain4jOpenAiConfig;
import io.quarkiverse.langchain4j.openai.runtime.config.ModerationModelConfig;
import io.quarkus.runtime.ShutdownContext;
import io.quarkus.runtime.annotations.Recorder;
import io.smallrye.config.ConfigValidationException;

@Recorder
Expand All @@ -38,14 +39,15 @@ public Supplier<?> chatModel(Langchain4jOpenAiConfig runtimeConfig) {
.maxRetries(runtimeConfig.maxRetries())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(chatModelConfig.modelName())
.temperature(chatModelConfig.temperature())
.topP(chatModelConfig.topP())
.presencePenalty(chatModelConfig.presencePenalty())
.frequencyPenalty(chatModelConfig.frequencyPenalty());

if (chatModelConfig.maxTokens().isPresent()) {
runtimeConfig.organizationId().ifPresent(builder::organizationId);

if (chatModelConfig.maxTokens().isPresent()) {
builder.maxTokens(chatModelConfig.maxTokens().get());
}

Expand All @@ -69,13 +71,14 @@ public Supplier<?> streamingChatModel(Langchain4jOpenAiConfig runtimeConfig) {
.timeout(runtimeConfig.timeout())
.logRequests(firstOrDefault(false, chatModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, chatModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(chatModelConfig.modelName())
.temperature(chatModelConfig.temperature())
.topP(chatModelConfig.topP())
.presencePenalty(chatModelConfig.presencePenalty())
.frequencyPenalty(chatModelConfig.frequencyPenalty());

runtimeConfig.organizationId().ifPresent(builder::organizationId);

if (chatModelConfig.maxTokens().isPresent()) {
builder.maxTokens(chatModelConfig.maxTokens().get());
}
Expand All @@ -101,9 +104,10 @@ public Supplier<?> embeddingModel(Langchain4jOpenAiConfig runtimeConfig) {
.maxRetries(runtimeConfig.maxRetries())
.logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, embeddingModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(embeddingModelConfig.modelName());

runtimeConfig.organizationId().ifPresent(builder::organizationId);

return new Supplier<>() {
@Override
public Object get() {
Expand All @@ -125,9 +129,10 @@ public Supplier<?> moderationModel(Langchain4jOpenAiConfig runtimeConfig) {
.maxRetries(runtimeConfig.maxRetries())
.logRequests(firstOrDefault(false, moderationModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, moderationModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(moderationModelConfig.modelName());

runtimeConfig.organizationId().ifPresent(builder::organizationId);

return new Supplier<>() {
@Override
public Object get() {
Expand All @@ -149,14 +154,15 @@ public Supplier<?> imageModel(Langchain4jOpenAiConfig runtimeConfig) {
.maxRetries(runtimeConfig.maxRetries())
.logRequests(firstOrDefault(false, imageModelConfig.logRequests(), runtimeConfig.logRequests()))
.logResponses(firstOrDefault(false, imageModelConfig.logResponses(), runtimeConfig.logResponses()))

.modelName(imageModelConfig.modelName())
.size(imageModelConfig.size())
.quality(imageModelConfig.quality())
.style(imageModelConfig.style())
.responseFormat(imageModelConfig.responseFormat())
.user(imageModelConfig.user());

runtimeConfig.organizationId().ifPresent(builder::organizationId);

// we persist if the directory was set explicitly and the boolean flag was not set to false
// or if the boolean flag was set explicitly to true
Optional<Path> persistDirectory = Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ public interface Langchain4jOpenAiConfig {
*/
Optional<String> apiKey();

/**
* OpenAI Organization ID (https://platform.openai.com/docs/api-reference/organization-optional)
*/
Optional<String> organizationId();

/**
* Timeout for OpenAI calls
*/
Expand Down

0 comments on commit 4e751eb

Please sign in to comment.