Skip to content

Commit

Permalink
Support tools in Athropic module
Browse files Browse the repository at this point in the history
  • Loading branch information
geoand committed Apr 16, 2024
1 parent 93b850e commit 6270f2e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ class ApiMetadata {
@HeaderParam("anthropic-version")
public final String anthropicVersion;

private ApiMetadata(String apiKey, String anthropicVersion) {
@HeaderParam("anthropic-beta")
public final String beta;

private ApiMetadata(String apiKey, String anthropicVersion, String beta) {
if ((apiKey == null) || apiKey.isBlank()) {
throw new IllegalArgumentException("apiKey cannot be null or blank");
}
Expand All @@ -72,6 +75,7 @@ private ApiMetadata(String apiKey, String anthropicVersion) {

this.apiKey = apiKey;
this.anthropicVersion = anthropicVersion;
this.beta = beta;
}

public static ApiMetadata.Builder builder() {
Expand All @@ -81,9 +85,10 @@ public static ApiMetadata.Builder builder() {
public static class Builder {
private String apiKey;
private String anthropicVersion;
private String beta;

public ApiMetadata build() {
return new ApiMetadata(this.apiKey, this.anthropicVersion);
return new ApiMetadata(this.apiKey, this.anthropicVersion, this.beta);
}

public ApiMetadata.Builder apiKey(String apiKey) {
Expand All @@ -95,6 +100,11 @@ public ApiMetadata.Builder anthropicVersion(String anthropicVersion) {
this.anthropicVersion = anthropicVersion;
return this;
}

public ApiMetadata.Builder beta(String beta) {
this.beta = beta;
return this;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkiverse.langchain4j.anthropic;

import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.model.anthropic.AnthropicMapper.toFinishReason;
import static java.util.Collections.synchronizedList;
import static java.util.stream.Collectors.joining;
Expand All @@ -26,7 +27,11 @@
import dev.langchain4j.model.anthropic.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.AnthropicHttpException;
import dev.langchain4j.model.anthropic.AnthropicMessage;
import dev.langchain4j.model.anthropic.AnthropicMessageContent;
import dev.langchain4j.model.anthropic.AnthropicStreamingData;
import dev.langchain4j.model.anthropic.AnthropicToolResultContent;
import dev.langchain4j.model.anthropic.AnthropicToolUseContent;
import dev.langchain4j.model.anthropic.AnthropicUsage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
Expand All @@ -39,6 +44,7 @@
import io.vertx.core.http.HttpClientResponse;

public class QuarkusAnthropicClient extends AnthropicClient {
public static final String BETA = "tools-2024-04-04";
private final String apiKey;
private final String anthropicVersion;
private final AnthropicRestApi restApi;
Expand All @@ -65,21 +71,40 @@ public QuarkusAnthropicClient(Builder builder) {

@Override
public AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageRequest request) {
return restApi.createMessage(request, createMetadata());
return restApi.createMessage(request, createMetadata(request));
}

@Override
public void createMessage(AnthropicCreateMessageRequest request, StreamingResponseHandler<AiMessage> handler) {
restApi.streamMessage(request, createMetadata())
restApi.streamMessage(request, createMetadata(request))
.subscribe()
.withSubscriber(new AnthropicStreamingSubscriber(handler));
}

private AnthropicRestApi.ApiMetadata createMetadata() {
return AnthropicRestApi.ApiMetadata.builder()
private AnthropicRestApi.ApiMetadata createMetadata(AnthropicCreateMessageRequest request) {
var builder = AnthropicRestApi.ApiMetadata.builder()
.apiKey(apiKey)
.anthropicVersion(anthropicVersion)
.build();
.anthropicVersion(anthropicVersion);
if (hasTools(request)) {
builder.beta(BETA);
}
return builder.build();
}

private boolean hasTools(AnthropicCreateMessageRequest request) {
if (!isNullOrEmpty(request.getTools())) {
return true;
}
List<AnthropicMessage> messages = request.getMessages();
for (AnthropicMessage message : messages) {
List<AnthropicMessageContent> contents = message.getContent();
for (AnthropicMessageContent content : contents) {
if ((content instanceof AnthropicToolUseContent) || (content instanceof AnthropicToolResultContent)) {
return true;
}
}
}
return false;
}

private static class AnthropicStreamingSubscriber implements MultiSubscriber<AnthropicStreamingData> {
Expand Down

0 comments on commit 6270f2e

Please sign in to comment.