From 6270f2edc0010acb9ae8d2fb57d471f4e60ce2fb Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Tue, 16 Apr 2024 14:36:43 +0300 Subject: [PATCH] Support tools in Athropic module --- .../anthropic/AnthropicRestApi.java | 14 ++++++- .../anthropic/QuarkusAnthropicClient.java | 37 ++++++++++++++++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/AnthropicRestApi.java b/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/AnthropicRestApi.java index 8bf979188..2d808c3af 100644 --- a/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/AnthropicRestApi.java +++ b/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/AnthropicRestApi.java @@ -61,7 +61,10 @@ class ApiMetadata { @HeaderParam("anthropic-version") public final String anthropicVersion; - private ApiMetadata(String apiKey, String anthropicVersion) { + @HeaderParam("anthropic-beta") + public final String beta; + + private ApiMetadata(String apiKey, String anthropicVersion, String beta) { if ((apiKey == null) || apiKey.isBlank()) { throw new IllegalArgumentException("apiKey cannot be null or blank"); } @@ -72,6 +75,7 @@ private ApiMetadata(String apiKey, String anthropicVersion) { this.apiKey = apiKey; this.anthropicVersion = anthropicVersion; + this.beta = beta; } public static ApiMetadata.Builder builder() { @@ -81,9 +85,10 @@ public static ApiMetadata.Builder builder() { public static class Builder { private String apiKey; private String anthropicVersion; + private String beta; public ApiMetadata build() { - return new ApiMetadata(this.apiKey, this.anthropicVersion); + return new ApiMetadata(this.apiKey, this.anthropicVersion, this.beta); } public ApiMetadata.Builder apiKey(String apiKey) { @@ -95,6 +100,11 @@ public ApiMetadata.Builder anthropicVersion(String anthropicVersion) { this.anthropicVersion = anthropicVersion; return this; } + + public ApiMetadata.Builder beta(String beta) { + this.beta = beta; + return this; + } } } diff --git a/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/QuarkusAnthropicClient.java b/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/QuarkusAnthropicClient.java index c1df2d624..131df5cbf 100644 --- a/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/QuarkusAnthropicClient.java +++ b/anthropic/runtime/src/main/java/io/quarkiverse/langchain4j/anthropic/QuarkusAnthropicClient.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.anthropic; import static dev.langchain4j.internal.Utils.isNotNullOrEmpty; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.model.anthropic.AnthropicMapper.toFinishReason; import static java.util.Collections.synchronizedList; import static java.util.stream.Collectors.joining; @@ -26,7 +27,11 @@ import dev.langchain4j.model.anthropic.AnthropicCreateMessageRequest; import dev.langchain4j.model.anthropic.AnthropicCreateMessageResponse; import dev.langchain4j.model.anthropic.AnthropicHttpException; +import dev.langchain4j.model.anthropic.AnthropicMessage; +import dev.langchain4j.model.anthropic.AnthropicMessageContent; import dev.langchain4j.model.anthropic.AnthropicStreamingData; +import dev.langchain4j.model.anthropic.AnthropicToolResultContent; +import dev.langchain4j.model.anthropic.AnthropicToolUseContent; import dev.langchain4j.model.anthropic.AnthropicUsage; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; @@ -39,6 +44,7 @@ import io.vertx.core.http.HttpClientResponse; public class QuarkusAnthropicClient extends AnthropicClient { + public static final String BETA = "tools-2024-04-04"; private final String apiKey; private final String anthropicVersion; private final AnthropicRestApi restApi; @@ -65,21 +71,40 @@ public QuarkusAnthropicClient(Builder builder) { @Override public AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageRequest request) { - return restApi.createMessage(request, createMetadata()); + return restApi.createMessage(request, createMetadata(request)); } @Override public void createMessage(AnthropicCreateMessageRequest request, StreamingResponseHandler handler) { - restApi.streamMessage(request, createMetadata()) + restApi.streamMessage(request, createMetadata(request)) .subscribe() .withSubscriber(new AnthropicStreamingSubscriber(handler)); } - private AnthropicRestApi.ApiMetadata createMetadata() { - return AnthropicRestApi.ApiMetadata.builder() + private AnthropicRestApi.ApiMetadata createMetadata(AnthropicCreateMessageRequest request) { + var builder = AnthropicRestApi.ApiMetadata.builder() .apiKey(apiKey) - .anthropicVersion(anthropicVersion) - .build(); + .anthropicVersion(anthropicVersion); + if (hasTools(request)) { + builder.beta(BETA); + } + return builder.build(); + } + + private boolean hasTools(AnthropicCreateMessageRequest request) { + if (!isNullOrEmpty(request.getTools())) { + return true; + } + List messages = request.getMessages(); + for (AnthropicMessage message : messages) { + List contents = message.getContent(); + for (AnthropicMessageContent content : contents) { + if ((content instanceof AnthropicToolUseContent) || (content instanceof AnthropicToolResultContent)) { + return true; + } + } + } + return false; } private static class AnthropicStreamingSubscriber implements MultiSubscriber {