Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update Spring AI dependency to 1.0.0-M3 #57

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
<spring-boot.version>3.3.3</spring-boot.version>

<!-- Spring AI -->
<spring-ai.version>1.0.0-M2</spring-ai.version>
<spring-ai.version>1.0.0-M3</spring-ai.version>
<dashscope-sdk-java.version>2.15.1</dashscope-sdk-java.version>

<!-- plugin versions -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@

package com.alibaba.cloud.ai.advisor;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.client.advisor.api.*;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentRetriever;
import org.springframework.ai.model.Content;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
Expand All @@ -37,7 +40,7 @@
* @since 1.0.0-M2
*/

public class DocumentRetrievalAdvisor implements RequestResponseAdvisor {
public class DocumentRetrievalAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

private static final String DEFAULT_USER_TEXT_ADVISE = """
请记住以下材料,他们可能对回答问题有帮助。
Expand All @@ -46,24 +49,85 @@ public class DocumentRetrievalAdvisor implements RequestResponseAdvisor {
---------------------
""";

/** Order value used by the constructors that do not take an explicit order. */
private static final int DEFAULT_ORDER = 0;

/**
 * Key under which the retrieved documents are stored in the advise context and
 * attached to the response metadata. Declared {@code final} so the shared key
 * cannot be reassigned at runtime.
 */
public static final String RETRIEVED_DOCUMENTS = "documents";

/** Source of the documents looked up for each user text. */
private final DocumentRetriever retriever;

/** Template passed to the request as system text. */
private final String userTextAdvise;

/**
 * When {@code true}, the streaming path runs the retrieval step on the bounded
 * elastic scheduler instead of the calling thread.
 */
private final boolean protectFromBlocking;

/** Value reported by {@code getOrder()}. */
private final int order;

/**
 * Creates an advisor that uses the default advise template.
 * @param retriever the retriever queried with the user text of each request
 */
public DocumentRetrievalAdvisor(DocumentRetriever retriever) {
    // An explicit constructor invocation must be the first statement, and the
    // final fields are assigned exactly once by the constructor it delegates to.
    this(retriever, DEFAULT_USER_TEXT_ADVISE);
}

/**
 * Creates an advisor with a custom advise template; delegates with
 * {@code protectFromBlocking} set to {@code true}.
 * @param retriever the retriever queried with the user text of each request
 * @param userTextAdvise the template applied to the request as system text
 */
public DocumentRetrievalAdvisor(DocumentRetriever retriever, String userTextAdvise) {
this(retriever, userTextAdvise, true);
}

/**
 * Creates an advisor with a custom advise template and blocking protection
 * setting; delegates with the order set to {@code DEFAULT_ORDER}.
 * @param retriever the retriever queried with the user text of each request
 * @param userTextAdvise the template applied to the request as system text
 * @param protectFromBlocking whether the streaming path moves the retrieval step
 * onto the bounded elastic scheduler
 */
public DocumentRetrievalAdvisor(DocumentRetriever retriever, String userTextAdvise, boolean protectFromBlocking) {
this(retriever, userTextAdvise, protectFromBlocking, DEFAULT_ORDER);
}

/**
 * Creates a fully configured advisor. The arguments are stored as given; no
 * validation is performed here.
 * @param retriever the retriever queried with the user text of each request
 * @param userTextAdvise the template applied to the request as system text
 * @param protectFromBlocking whether the streaming path moves the retrieval step
 * onto the bounded elastic scheduler
 * @param order the value returned by {@code getOrder()}
 */
public DocumentRetrievalAdvisor(DocumentRetriever retriever, String userTextAdvise, boolean protectFromBlocking,
int order) {
this.retriever = retriever;
this.userTextAdvise = userTextAdvise;
this.protectFromBlocking = protectFromBlocking;
this.order = order;
}

/**
 * Blocking advice: enriches the request with retrieved documents, hands it to
 * the rest of the chain and post-processes the response that comes back.
 * @param advisedRequest the incoming request
 * @param chain the remaining advisor chain
 * @return the chain's response with the retrieved documents added as metadata
 */
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
    AdvisedRequest enrichedRequest = before(advisedRequest);
    AdvisedResponse chainResponse = chain.nextAroundCall(enrichedRequest);
    return after(chainResponse);
}

@Override
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {

    // Callers may be blocking threads (command line, Tomcat) or a non-blocking
    // WebFlux dispatch. With protection enabled, the retrieval step in before()
    // is shifted onto the bounded elastic scheduler; otherwise it runs inline.
    Flux<AdvisedResponse> downstream;
    if (this.protectFromBlocking) {
        downstream = Mono.just(advisedRequest)
            .publishOn(Schedulers.boundedElastic())
            .map(this::before)
            .flatMapMany(enriched -> chain.nextAroundStream(enriched));
    }
    else {
        downstream = chain.nextAroundStream(before(advisedRequest));
    }

    // Only elements that carry a finish reason are post-processed.
    Predicate<AdvisedResponse> carriesFinishReason = onFinishReason();
    return downstream.map(element -> carriesFinishReason.test(element) ? after(element) : element);
}

/**
 * Returns the simple name of this object's runtime class as the advisor name.
 * @return the simple class name
 */
@Override
public String getName() {
return this.getClass().getSimpleName();
}

/**
 * Returns the order value supplied at construction time.
 * @return the configured order
 */
@Override
public int getOrder() {
return this.order;
}

private AdvisedRequest before(AdvisedRequest request) {

var context = new HashMap<>(request.adviseContext());

List<Document> documents = retriever.retrieve(request.userText());

context.put(RETRIEVED_DOCUMENTS, documents);
Expand All @@ -79,21 +143,35 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
.withSystemText(this.userTextAdvise)
.withSystemParams(advisedUserParams)
.withUserText(request.userText())
.withAdviseContext(context)
.build();
}

@Override
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response);
return chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)).build();
/**
 * Copies the chat response and attaches the documents stored in the advise
 * context under {@code RETRIEVED_DOCUMENTS} as response metadata.
 * @param advisedResponse the response produced by the rest of the chain
 * @return a new advised response with the same advise context
 */
private AdvisedResponse after(AdvisedResponse advisedResponse) {
    var adviseContext = advisedResponse.adviseContext();
    ChatResponse enriched = ChatResponse.builder()
        .from(advisedResponse.response())
        .withMetadata(RETRIEVED_DOCUMENTS, adviseContext.get(RETRIEVED_DOCUMENTS))
        .build();
    return new AdvisedResponse(enriched, adviseContext);
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
return fluxResponse.map(cr -> {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
return chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS)).build();
});
/**
 * Builds the predicate that decides whether
 * {@link DocumentRetrievalAdvisor#after(AdvisedResponse)} is applied to a streamed
 * element.<br />
 * The predicate accepts an element when at least one of its results is non-null,
 * has metadata, and that metadata carries a non-blank finish reason. This is
 * usually the last element in the Flux.<br />
 * Inspired by
 * {@link org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor}.
 */
private Predicate<AdvisedResponse> onFinishReason() {
    return advisedResponse -> advisedResponse.response()
        .getResults()
        .stream()
        .anyMatch(result -> result != null && result.getMetadata() != null
                && StringUtils.hasText(result.getMetadata().getFinishReason()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
import com.alibaba.cloud.ai.model.RerankResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.client.advisor.api.*;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
Expand All @@ -34,8 +33,11 @@
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toList;
Expand All @@ -48,7 +50,7 @@
* @since 1.0.0-M2
*/

public class RetrievalRerankAdvisor implements RequestResponseAdvisor {
public class RetrievalRerankAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

private static final Logger logger = LoggerFactory.getLogger(RetrievalRerankAdvisor.class);

Expand All @@ -64,6 +66,8 @@ public class RetrievalRerankAdvisor implements RequestResponseAdvisor {

private static final Double DEFAULT_MIN_SCORE = 0.1;

private static final int DEFAULT_ORDER = 0;

private final VectorStore vectorStore;

private final RerankModel rerankModel;
Expand All @@ -74,6 +78,10 @@ public class RetrievalRerankAdvisor implements RequestResponseAdvisor {

private final Double minScore;

private final boolean protectFromBlocking;

private final int order;

public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";

public static final String FILTER_EXPRESSION = "qa_filter_expression";
Expand All @@ -94,6 +102,16 @@ public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel,

/**
 * Creates an advisor with blocking protection enabled; delegates with
 * {@code protectFromBlocking} set to {@code true}.
 * @param vectorStore the store searched for similar documents
 * @param rerankModel the rerank model (its use is not visible in this constructor)
 * @param searchRequest the template search request copied for each query
 * @param userTextAdvise the text appended to the user text
 * @param minScore score threshold; presumably applied during reranking — confirm
 * against {@code doRerank}
 */
public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
String userTextAdvise, Double minScore) {
this(vectorStore, rerankModel, searchRequest, userTextAdvise, minScore, true);
}

/**
 * Creates an advisor with an explicit blocking protection setting; delegates
 * with the order set to {@code DEFAULT_ORDER}.
 * @param vectorStore the store searched for similar documents
 * @param rerankModel the rerank model (its use is not visible in this constructor)
 * @param searchRequest the template search request copied for each query
 * @param userTextAdvise the text appended to the user text
 * @param minScore score threshold; presumably applied during reranking — confirm
 * against {@code doRerank}
 * @param protectFromBlocking whether the streaming path moves the retrieval step
 * onto the bounded elastic scheduler
 */
public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
String userTextAdvise, Double minScore, boolean protectFromBlocking) {
this(vectorStore, rerankModel, searchRequest, userTextAdvise, minScore, protectFromBlocking, DEFAULT_ORDER);
}

public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel, SearchRequest searchRequest,
String userTextAdvise, Double minScore, boolean protectFromBlocking, int order) {
Assert.notNull(vectorStore, "The vectorStore must not be null!");
Assert.notNull(rerankModel, "The rerankModel must not be null!");
Assert.notNull(searchRequest, "The searchRequest must not be null!");
Expand All @@ -104,53 +122,51 @@ public RetrievalRerankAdvisor(VectorStore vectorStore, RerankModel rerankModel,
this.userTextAdvise = userTextAdvise;
this.searchRequest = searchRequest;
this.minScore = minScore;
this.protectFromBlocking = protectFromBlocking;
this.order = order;
}

@Override
public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context) {
// 1. Advise the system text.
String advisedUserText = request.userText() + System.lineSeparator() + this.userTextAdvise;
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {

var searchRequestToUse = SearchRequest.from(this.searchRequest)
.withQuery(request.userText())
.withFilterExpression(doGetFilterExpression(context));
advisedRequest = this.before(advisedRequest);

// 2. Search for similar documents in the vector store.
logger.debug("searchRequestToUse: {}", searchRequestToUse);
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
logger.debug("retrieved documents: {}", documents);
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);

// 3. Rerank documents for query
documents = doRerank(request, documents);

context.put(RETRIEVED_DOCUMENTS, documents);

// 4. Create the context from the documents.
String documentContext = documents.stream()
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));

// 5. Advise the user parameters.
Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
advisedUserParams.put("question_answer_context", documentContext);
return this.after(advisedResponse);
}

return AdvisedRequest.from(request).withUserText(advisedUserText).withUserParams(advisedUserParams).build();
/**
 * Streaming advice: enriches the request with reranked documents, forwards it to
 * the rest of the chain and post-processes the elements that carry a finish
 * reason.
 * @param advisedRequest the incoming request
 * @param chain the remaining advisor chain
 * @return the chain's response stream, with matching elements post-processed
 */
@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {

    // Callers may be blocking threads (command line, Tomcat) or a non-blocking
    // WebFlux dispatch. With protection enabled, the retrieval step in before()
    // is shifted onto the bounded elastic scheduler; otherwise it runs inline.
    Flux<AdvisedResponse> downstream;
    if (this.protectFromBlocking) {
        downstream = Mono.just(advisedRequest)
            .publishOn(Schedulers.boundedElastic())
            .map(this::before)
            .flatMapMany(enriched -> chain.nextAroundStream(enriched));
    }
    else {
        downstream = chain.nextAroundStream(before(advisedRequest));
    }

    // Only elements that carry a finish reason are post-processed.
    Predicate<AdvisedResponse> carriesFinishReason = onFinishReason();
    return downstream.map(element -> carriesFinishReason.test(element) ? after(element) : element);
}

@Override
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response);
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return chatResponseBuilder.build();
public String getName() {
return this.getClass().getSimpleName();
}

@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
return fluxResponse.map(cr -> {
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return chatResponseBuilder.build();
});
public int getOrder() {
return this.order;
}

protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
Expand Down Expand Up @@ -182,4 +198,68 @@ protected List<Document> doRerank(AdvisedRequest request, List<Document> documen
.collect(toList());
}

/**
 * Prepares the request before it is passed down the chain: searches the vector
 * store with the user text, reranks the hits, records them in a copy of the
 * advise context and exposes their joined content as the
 * {@code question_answer_context} user parameter.
 * @param request the incoming request
 * @return a copy of the request with advised user text, user params and context
 */
private AdvisedRequest before(AdvisedRequest request) {

    final String question = request.userText();

    // Work on a copy so the incoming request's context is left untouched.
    var adviseContext = new HashMap<>(request.adviseContext());

    // 1. Derive the search request for this question.
    var similarityQuery = SearchRequest.from(this.searchRequest)
        .withQuery(question)
        .withFilterExpression(doGetFilterExpression(adviseContext));

    // 2. Look up similar documents in the vector store.
    logger.debug("searchRequestToUse: {}", similarityQuery);
    List<Document> candidates = this.vectorStore.similaritySearch(similarityQuery);
    logger.debug("retrieved documents: {}", candidates);

    // 3. Rerank the candidates for the question and remember the result.
    List<Document> reranked = doRerank(request, candidates);
    adviseContext.put(RETRIEVED_DOCUMENTS, reranked);

    // 4. Join the document contents, one per line.
    String documentContext = reranked.stream()
        .map(Content::getContent)
        .collect(Collectors.joining(System.lineSeparator()));

    // 5. Expose the joined content as a user parameter.
    Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
    advisedUserParams.put("question_answer_context", documentContext);

    // 6. Append the advise template to the user text and rebuild the request.
    return AdvisedRequest.from(request)
        .withUserText(question + System.lineSeparator() + this.userTextAdvise)
        .withUserParams(advisedUserParams)
        .withAdviseContext(adviseContext)
        .build();
}

/**
 * Copies the chat response and attaches the documents stored in the advise
 * context under {@code RETRIEVED_DOCUMENTS} as response metadata.
 * @param advisedResponse the response produced by the rest of the chain
 * @return a new advised response with the same advise context
 */
private AdvisedResponse after(AdvisedResponse advisedResponse) {
    var adviseContext = advisedResponse.adviseContext();
    ChatResponse enriched = ChatResponse.builder()
        .from(advisedResponse.response())
        .withMetadata(RETRIEVED_DOCUMENTS, adviseContext.get(RETRIEVED_DOCUMENTS))
        .build();
    return new AdvisedResponse(enriched, adviseContext);
}

/**
 * Builds the predicate that decides whether
 * {@link RetrievalRerankAdvisor#after(AdvisedResponse)} is applied to a streamed
 * element.<br />
 * The predicate accepts an element when at least one of its results is non-null,
 * has metadata, and that metadata carries a non-blank finish reason. This is
 * usually the last element in the Flux.<br />
 * Inspired by
 * {@link org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor}.
 */
private Predicate<AdvisedResponse> onFinishReason() {
    return advisedResponse -> advisedResponse.response()
        .getResults()
        .stream()
        .anyMatch(result -> result != null && result.getMetadata() != null
                && StringUtils.hasText(result.getMetadata().getFinishReason()));
}

}
Loading