Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
# Conflicts:
#	spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java
  • Loading branch information
Manni, William authored and Manni, William committed Jul 29, 2024
2 parents 3fe103d + 6363352 commit ee9b28a
Show file tree
Hide file tree
Showing 343 changed files with 8,300 additions and 3,174 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ package.json

shell.log

.profiler
.profiler
/spring-ai-spring-boot-autoconfigure/nbproject/
/vector-stores/spring-ai-cassandra-store/nbproject/
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ private Builder() {
*/
public PdfDocumentReaderConfig.Builder withPageExtractedTextFormatter(
        ExtractedTextFormatter pageExtractedTextFormatter) {
    // Validate the argument itself. A previous check passed the boolean expression
    // `pagesPerDocument >= 0` to Assert.notNull; that value autoboxes to a non-null
    // Boolean, so the assertion could never fail and guarded nothing.
    Assert.notNull(pageExtractedTextFormatter, "PageExtractedTextFormatter must not be null.");
    this.pageExtractedTextFormatter = pageExtractedTextFormatter;
    // Return the builder so that configuration calls can be chained.
    return this;
}
Expand Down
7 changes: 3 additions & 4 deletions models/spring-ai-anthropic/pom.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.ai</groupId>
Expand Down Expand Up @@ -78,11 +79,9 @@
<dependency>
<groupId>com.fasterxml.jackson.dataformat</groupId>
<artifactId>jackson-dataformat-xml</artifactId>
<!-- <version>2.16.1</version> -->
<version>2.11.1</version>
<scope>test</scope>
</dependency>

</dependencies>

</project>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -31,26 +30,32 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
* The {@link ChatModel} implementation for the Anthropic service.
Expand All @@ -60,13 +65,11 @@
* @author Mariusz Bernacki
* @since 1.0.0
*/
public class AnthropicChatModel extends
AbstractFunctionCallSupport<AnthropicApi.AnthropicMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletionResponse>>
implements ChatModel {
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {

private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);

public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue();
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue();

public static final Integer DEFAULT_MAX_TOKENS = 500;

Expand Down Expand Up @@ -147,41 +150,97 @@ public ChatResponse call(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, false);

return this.retryTemplate.execute(ctx -> {
ResponseEntity<ChatCompletionResponse> completionEntity = this.callWithFunctionSupport(request);
return toChatResponse(completionEntity.getBody());
});
ResponseEntity<ChatCompletionResponse> completionEntity = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionEntity(request));

ChatResponse chatResponse = toChatResponse(completionEntity.getBody());

if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
}

return chatResponse;
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {

ChatCompletionRequest request = createRequest(prompt, true);

return this.retryTemplate.execute(ctx -> {
Flux<ChatCompletionResponse> response = this.retryTemplate
.execute(ctx -> this.anthropicApi.chatCompletionStream(request));

return response.switchMap(chatCompletionResponse -> {

Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);

if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
var toolCallConversation = handleToolCalls(prompt, chatResponse);
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
}

return response
.switchMap(chatCompletionResponse -> handleFunctionCallOrReturnStream(request,
Flux.just(ResponseEntity.of(Optional.of(chatCompletionResponse)))))
.map(ResponseEntity::getBody)
.map(this::toChatResponse);
return Mono.just(chatResponse);
});
}

/**
 * Converts an Anthropic completion into a {@link ChatResponse}.
 *
 * <p>Every content block that is not of type {@code TOOL_USE} becomes its own
 * {@link Generation} carrying the block text. All {@code TOOL_USE} blocks are gathered
 * into a single additional generation whose assistant message has empty text and the
 * list of tool calls. Each generation carries the completion's stop reason.
 *
 * @param chatCompletion the completion returned by the API, may be {@code null}
 * @return the converted response; a response with no generations when the completion
 * is {@code null}
 */
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {

    if (chatCompletion == null) {
        logger.warn("Null chat completion returned");
        return new ChatResponse(List.of());
    }

    // Plain (non tool-use) content blocks: one generation per block.
    List<Generation> generations = chatCompletion.content()
        .stream()
        .filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
        .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
                ChatGenerationMetadata.from(chatCompletion.stopReason(), null)))
        .toList();

    // Stream.toList() is unmodifiable, so copy before appending the tool-call generation.
    List<Generation> allGenerations = new ArrayList<>(generations);

    List<ContentBlock> toolToUseList = chatCompletion.content()
        .stream()
        .filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
        .toList();

    if (!CollectionUtils.isEmpty(toolToUseList)) {
        List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();

        for (ContentBlock toolToUse : toolToUseList) {

            var functionCallId = toolToUse.id();
            var functionName = toolToUse.name();
            // The tool input is serialized to a JSON string for the ToolCall arguments.
            var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input());

            toolCalls
                .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
        }

        AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
        Generation toolCallGeneration = new Generation(assistantMessage,
                ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
        allGenerations.add(toolCallGeneration);
    }

    return new ChatResponse(allGenerations, this.from(chatCompletion));
}

/**
 * Builds the response metadata for a completion: id, model and usage, plus the
 * {@code stop-reason}, {@code stop-sequence} and {@code type} key/value entries.
 *
 * @param result the completion to describe, must not be {@code null}
 * @return the assembled metadata
 */
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
    Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");

    final AnthropicUsage tokenUsage = AnthropicUsage.from(result.usage());

    return ChatResponseMetadata.builder()
        .withId(result.id())
        .withModel(result.model())
        .withUsage(tokenUsage)
        .withKeyValue("stop-reason", result.stopReason())
        .withKeyValue("stop-sequence", result.stopSequence())
        .withKeyValue("type", result.type())
        .build();
}

private String fromMediaData(Object mediaData) {
Expand All @@ -203,18 +262,47 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

List<AnthropicMessage> userMessages = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() != MessageType.SYSTEM)
.map(m -> {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(m.getContent())));
if (!CollectionUtils.isEmpty(m.getMedia())) {
List<ContentBlock> mediaContent = m.getMedia()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getContent())));
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia()
.stream()
.map(media -> new ContentBlock(media.getMimeType().toString(),
this.fromMediaData(media.getData())))
.toList();
contents.addAll(mediaContent);
}
}
return new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name()));
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getContent())) {
contentBlocks.add(new ContentBlock(message.getContent()));
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
}
}
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
}
else if (message.getMessageType() == MessageType.TOOL) {
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
.stream()
.map(media -> new ContentBlock(media.getMimeType().toString(),
this.fromMediaData(media.getData())))
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
toolResponse.responseData()))
.toList();
contents.addAll(mediaContent);
return new AnthropicMessage(toolResponses, Role.USER);
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
return new AnthropicMessage(contents, Role.valueOf(m.getMessageType().name()));
})
.toList();

Expand Down Expand Up @@ -265,74 +353,6 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
}).toList();
}

/**
 * Executes every {@code TOOL_USE} block of the given response message through the
 * registered function callbacks, appends the collected tool results to the
 * conversation as a single user message, and builds the follow-up request.
 *
 * @param previousRequest the request the follow-up request is derived from
 * @param responseMessage the assistant message containing the tool-use blocks
 * @param conversationHistory the conversation so far; the tool results are appended
 * @return a copy of the previous request carrying the extended conversation
 * @throws IllegalStateException if no callback is registered for a requested function
 */
@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
        AnthropicMessage responseMessage, List<AnthropicMessage> conversationHistory) {

    List<ContentBlock> results = new ArrayList<>();

    for (ContentBlock block : responseMessage.content()) {
        // Only tool-use blocks trigger a function invocation.
        if (block.type() != ContentBlock.ContentBlockType.TOOL_USE) {
            continue;
        }

        var callId = block.id();
        var name = block.name();
        var arguments = block.input();

        if (!this.functionCallbackRegister.containsKey(name)) {
            throw new IllegalStateException("No function callback found for function name: " + name);
        }

        String callbackOutput = this.functionCallbackRegister.get(name)
            .call(ModelOptionsUtils.toJsonString(arguments));

        results.add(new ContentBlock(ContentBlockType.TOOL_RESULT, callId, callbackOutput));
    }

    // The tool results go back to the model as one user-role message.
    conversationHistory.add(new AnthropicMessage(results, Role.USER));

    // The caller keeps iterating until the model stops requesting functions.
    return ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
}

/**
 * Returns the messages carried by the given request.
 *
 * @param request the request to read
 * @return the request's message list
 */
@Override
protected List<AnthropicMessage> doGetUserMessages(ChatCompletionRequest request) {
    List<AnthropicMessage> requestMessages = request.messages();
    return requestMessages;
}

/**
 * Wraps the content of the response body in an assistant-role message.
 *
 * @param response the completion response whose body content is used
 * @return an assistant message holding the body's content blocks
 */
@Override
protected AnthropicMessage doGetToolResponseMessage(ResponseEntity<ChatCompletionResponse> response) {
    ChatCompletionResponse completion = response.getBody();
    return new AnthropicMessage(completion.content(), Role.ASSISTANT);
}

/**
 * Delegates the (non-streaming) chat completion call to the Anthropic API client.
 *
 * @param request the request to send
 * @return the response entity returned by the API client
 */
@Override
protected ResponseEntity<ChatCompletionResponse> doChatCompletion(ChatCompletionRequest request) {
    ResponseEntity<ChatCompletionResponse> completionEntity = this.anthropicApi.chatCompletionEntity(request);
    return completionEntity;
}

/**
 * Tells whether the response contains at least one {@code TOOL_USE} content block.
 *
 * @param response the completion response, may be {@code null}
 * @return {@code true} if any content block is a tool-use block; {@code false} when the
 * response, its body or its content is missing or empty, or no such block exists
 */
@SuppressWarnings("null")
@Override
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletionResponse> response) {
    if (response == null) {
        return false;
    }

    ChatCompletionResponse body = response.getBody();
    if (body == null || CollectionUtils.isEmpty(body.content())) {
        return false;
    }

    for (ContentBlock block : body.content()) {
        if (block.type() == ContentBlock.ContentBlockType.TOOL_USE) {
            return true;
        }
    }
    return false;
}

/**
 * Delegates the streaming chat completion call to the Anthropic API client and wraps
 * each streamed completion in a {@link ResponseEntity}.
 *
 * @param request the request to send
 * @return a {@link Flux} of response entities, one per streamed completion
 */
@Override
protected Flux<ResponseEntity<ChatCompletionResponse>> doChatCompletionStream(ChatCompletionRequest request) {
    Flux<ChatCompletionResponse> completions = this.anthropicApi.chatCompletionStream(request);
    return completions.map(Optional::ofNullable).map(ResponseEntity::of);
}

@Override
public ChatOptions getDefaultOptions() {
return AnthropicChatOptions.fromOptions(this.defaultOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ public void setFunctions(Set<String> functions) {
this.functions = functions;
}

/**
 * Creates a copy of these options by passing this instance to
 * {@link #fromOptions(AnthropicChatOptions)}.
 *
 * @return the options instance produced by {@code fromOptions}
 */
@Override
public AnthropicChatOptions copy() {
    return AnthropicChatOptions.fromOptions(this);
}

public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
return builder().withModel(fromOptions.getModel())
.withMaxTokens(fromOptions.getMaxTokens())
Expand Down
Loading

0 comments on commit ee9b28a

Please sign in to comment.