From 2e8f138f4d22bcbc8215ccc922af857d39c9456e Mon Sep 17 00:00:00 2001 From: Eric Deandrea Date: Fri, 2 Feb 2024 12:02:57 -0500 Subject: [PATCH] Only pass Authorization/api-key headers when necessary --- .../openai/QuarkusRestApiResource.java | 16 ++-- .../langchain4j/openai/OpenAiRestApi.java | 34 ++++++--- .../openai/QuarkusOpenAiClient.java | 76 +++++++++++-------- .../openai/test/OpenAiRestApiSmokeTest.java | 6 +- 4 files changed, 77 insertions(+), 55 deletions(-) diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java index 051be47f6..78691a536 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/QuarkusRestApiResource.java @@ -66,7 +66,7 @@ public String chatSync() { return restApi.blockingChatCompletion( createChatCompletionRequest("Write a short 1 paragraph funny poem about segmentation fault"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .content(); @@ -78,7 +78,7 @@ public Uni chatAsync() { return restApi .createChatCompletion(createChatCompletionRequest("Write a short 1 paragraph funny poem about Unicode"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .map(ChatCompletionResponse::content); @@ -91,7 +91,7 @@ public Multi chatStreaming() { return restApi.streamingChatCompletion( createChatCompletionRequest("Write a short 1 paragraph funny poem about Enterprise Java"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .map(r -> { @@ -124,7 +124,7 @@ public String languageSync() { return restApi.blockingCompletion( createCompletionRequest("Write a short 1 paragraph funny poem about segmentation fault"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .text(); @@ -136,7 +136,7 @@ public Uni languageAsync() { return restApi .completion(createCompletionRequest("Write a short 1 paragraph funny poem about Unicode"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .map(CompletionResponse::text); @@ -149,7 +149,7 @@ public Multi languageStreaming() { return restApi.streamingCompletion( createCompletionRequest("Write a short 1 paragraph funny poem about Enterprise Java"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .map(r -> { @@ -171,7 +171,7 @@ public Multi languageStreaming() { public List embeddingSync() { return restApi.blockingEmbedding(createEmbeddingRequest("Your text string goes here"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .embedding(); @@ -183,7 +183,7 @@ public Uni> embeddingAsync() { return restApi .embedding(createEmbeddingRequest("Your text string goes here"), OpenAiRestApi.ApiMetadata.builder() - .apiKey(token) + .openAiApiKey(token) .organizationId(organizationId) .build()) .map(EmbeddingResponse::embedding); diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java index 1a5260276..62f49c609 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/OpenAiRestApi.java @@ -416,17 +416,17 @@ class ApiMetadata { public final String authorization; @HeaderParam("api-key") - public final String apiKey; + public final String azureApiKey; @QueryParam("api-version") public final String apiVersion; @HeaderParam("OpenAI-Organization") public final String organizationId; - private ApiMetadata(String authorization, String apiKey, + private ApiMetadata(String openaiApiKey, String azureApiKey, String apiVersion, String organizationId) { - this.authorization = authorization; - this.apiKey = apiKey; + this.authorization = (openaiApiKey != null) ? "Bearer " + openaiApiKey : null; + this.azureApiKey = azureApiKey; this.apiVersion = apiVersion; this.organizationId = organizationId; } @@ -436,20 +436,30 @@ public static ApiMetadata.Builder builder() { } public static class Builder { - private String apiKey; + private String azureApiKey; + private String openAiApiKey; private String apiVersion; private String organizationId; public ApiMetadata build() { - return (apiKey == null) ? new ApiMetadata(null, null, apiVersion, organizationId) - : new ApiMetadata( - "Bearer " + apiKey, // typical OpenAI authentication - apiKey, // used by AzureAI - apiVersion, organizationId); + if ((azureApiKey != null) && (openAiApiKey != null)) { + return new ApiMetadata(openAiApiKey, azureApiKey, apiVersion, organizationId); + } else if (azureApiKey != null) { + return new ApiMetadata(null, azureApiKey, apiVersion, organizationId); + } else if (openAiApiKey != null) { + return new ApiMetadata(openAiApiKey, null, apiVersion, organizationId); + } + + return new ApiMetadata(null, null, apiVersion, organizationId); + } + + public ApiMetadata.Builder azureApiKey(String azureApiKey) { + this.azureApiKey = azureApiKey; + return this; } - public ApiMetadata.Builder apiKey(String apiKey) { - this.apiKey = apiKey; + public ApiMetadata.Builder openAiApiKey(String openAiApiKey) { + this.openAiApiKey = openAiApiKey; return this; } diff --git a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java index 0a7b0d161..40603b8aa 100644 --- a/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java +++ b/openai/openai-common/runtime/src/main/java/io/quarkiverse/langchain4j/openai/QuarkusOpenAiClient.java @@ -48,7 +48,8 @@ */ public class QuarkusOpenAiClient extends OpenAiClient { - private final String apiKey; + private final String azureApiKey; + private final String openaiApiKey; private final String apiVersion; private final String organizationId; @@ -56,8 +57,8 @@ public class QuarkusOpenAiClient extends OpenAiClient { private static final Map cache = new ConcurrentHashMap<>(); - public QuarkusOpenAiClient(String apiKey) { - this(new Builder().openAiApiKey(apiKey)); + public QuarkusOpenAiClient(String openaiApiKey) { + this(new Builder().openAiApiKey(openaiApiKey)); } public static Builder builder() { @@ -69,7 +70,8 @@ public static void clearCache() { } private QuarkusOpenAiClient(Builder builder) { - this.apiKey = determineApiKey(builder); + this.azureApiKey = builder.azureApiKey; + this.openaiApiKey = builder.openAiApiKey; this.apiVersion = builder.apiVersion; this.organizationId = builder.organizationId; // cache the client the builder could be called with the same parameters from multiple models @@ -106,15 +108,6 @@ public OpenAiRestApi apply(Builder builder, OpenAiRestApi openAiRestApi) { } - private static String determineApiKey(Builder builder) { - var result = builder.openAiApiKey; - if (result != null) { - return result; - } - result = builder.azureApiKey; - return result; - } - @Override public SyncOrAsyncOrStreaming completion(CompletionRequest request) { return new SyncOrAsyncOrStreaming<>() { @@ -123,7 +116,8 @@ public CompletionResponse execute() { return restApi.blockingCompletion( CompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -137,7 +131,8 @@ public AsyncResponseHandling onResponse(Consumer responseHan public Uni get() { return restApi.completion(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -155,7 +150,8 @@ public StreamingResponseHandling onPartialResponse( public Multi get() { return restApi.streamingCompletion(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -178,7 +174,8 @@ public ChatCompletionResponse execute() { return restApi.blockingChatCompletion( ChatCompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -192,7 +189,8 @@ public AsyncResponseHandling onResponse(Consumer respons public Uni get() { return restApi.createChatCompletion(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -210,7 +208,8 @@ public StreamingResponseHandling onPartialResponse( public Multi get() { return restApi.streamingChatCompletion(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -232,7 +231,8 @@ public String execute() { return restApi .blockingChatCompletion(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -249,7 +249,8 @@ public Uni get() { .createChatCompletion( ChatCompletionRequest.builder().from(request).stream(null).build(), OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -270,7 +271,8 @@ public Multi get() { .streamingChatCompletion( ChatCompletionRequest.builder().from(request).stream(true).build(), OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -300,7 +302,8 @@ public SyncOrAsync embedding(EmbeddingRequest request) { public EmbeddingResponse execute() { return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -314,7 +317,8 @@ public AsyncResponseHandling onResponse(Consumer responseHand public Uni get() { return restApi.embedding(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -335,7 +339,8 @@ public SyncOrAsync> embedding(String input) { public List execute() { return restApi.blockingEmbedding(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -350,7 +355,8 @@ public AsyncResponseHandling onResponse(Consumer> responseHandler) { public Uni> get() { return restApi.embedding(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -369,7 +375,8 @@ public SyncOrAsync moderation(ModerationRequest request) { public ModerationResponse execute() { return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -383,7 +390,8 @@ public AsyncResponseHandling onResponse(Consumer responseHan public Uni get() { return restApi.moderation(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -405,7 +413,8 @@ public SyncOrAsync moderation(String input) { public ModerationResult execute() { return restApi.blockingModeration(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -420,7 +429,8 @@ public AsyncResponseHandling onResponse(Consumer responseHandl public Uni get() { return restApi.moderation(request, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()) @@ -439,7 +449,8 @@ public SyncOrAsync imagesGeneration(GenerateImagesReques public GenerateImagesResponse execute() { return restApi.blockingImagesGenerations(generateImagesRequest, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); @@ -453,7 +464,8 @@ public AsyncResponseHandling onResponse(Consumer respons public Uni get() { return restApi.imagesGenerations(generateImagesRequest, OpenAiRestApi.ApiMetadata.builder() - .apiKey(apiKey) + .azureApiKey(azureApiKey) + .openAiApiKey(openaiApiKey) .apiVersion(apiVersion) .organizationId(organizationId) .build()); diff --git a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java index f8197c4b3..b0f80eb6c 100644 --- a/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java +++ b/openai/openai-vanilla/deployment/src/test/java/io/quarkiverse/langchain4j/openai/test/OpenAiRestApiSmokeTest.java @@ -52,7 +52,7 @@ void happyPath() throws URISyntaxException { OpenAiRestApi restApi = createClient(); ChatCompletionResponse response = restApi.blockingChatCompletion(ChatCompletionRequest.builder().build(), - OpenAiRestApi.ApiMetadata.builder().apiKey(TOKEN).organizationId(ORGANIZATION).build()); + OpenAiRestApi.ApiMetadata.builder().openAiApiKey(TOKEN).organizationId(ORGANIZATION).build()); assertThat(response).isNotNull(); wireMockServer.verify(WiremockUtils.chatCompletionRequestPattern(TOKEN, ORGANIZATION)); @@ -69,7 +69,7 @@ void server500() throws URISyntaxException { OpenAiRestApi restApi = createClient(); assertThatThrownBy(() -> restApi.blockingChatCompletion(ChatCompletionRequest.builder().build(), - OpenAiRestApi.ApiMetadata.builder().apiKey(TOKEN).build())) + OpenAiRestApi.ApiMetadata.builder().openAiApiKey(TOKEN).build())) .isInstanceOf( OpenAiHttpException.class) .hasMessage("This is a dummy error message"); @@ -100,7 +100,7 @@ void server200ButAPIError() throws URISyntaxException { OpenAiRestApi restApi = createClient(); assertThatThrownBy(() -> restApi.blockingChatCompletion(ChatCompletionRequest.builder().build(), - OpenAiRestApi.ApiMetadata.builder().apiKey(TOKEN).build())) + OpenAiRestApi.ApiMetadata.builder().openAiApiKey(TOKEN).build())) .isInstanceOf( OpenAiApiException.class); }