Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

知识库新增检索功能 #673

Closed
wants to merge 20 commits into from
36 changes: 36 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,39 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk

return rsp, nil
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/v2/knowledgebases/query")
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
request.URL = serviceURL
request.Method = "POST"
header.Set("Content-Type", "application/json")
request.Header = header
data, _ := json.Marshal(req)
request.Body = NopCloser(bytes.NewReader(data))
t.sdkConfig.BuildCurlCommand(&request)
resp, err := t.client.Do(&request)
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
defer resp.Body.Close()
requestID, err := checkHTTPResponse(resp)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}
data, err = io.ReadAll(resp.Body)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

rsp := QueryKnowledgeBaseResponse{}
if err := json.Unmarshal(data, &rsp); err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

return rsp, nil
}
87 changes: 87 additions & 0 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package appbuilder

import "time"

const (
ContentTypeRawText = "raw_text"
ContentTypeQA = "qa"
Expand Down Expand Up @@ -245,3 +247,88 @@ type DescribeChunksResponse struct {
NextMarker string `json:"nextMarker"`
MaxKeys int `json:"maxKeys"`
}

type MetadataFilter struct {
Operator string `json:"operator"`
Field string `json:"field"`
Value interface{} `json:"value"` // 因为Value的类型可以是str或list[str],所以我们使用interface{}来表示任何类型
}

type MetadataFilters struct {
Filters []MetadataFilter `json:"filters"`
Condition string `json:"condition"`
}

type PreRankingConfig struct {
Bm25Weight float64 `json:"bm25_weight"`
VecWeight float64 `json:"vec_weight"`
Bm25B float64 `json:"bm25_b"`
Bm25K1 float64 `json:"bm25_k1"`
Bm25MaxScore float64 `json:"bm25_max_score"`
}

type ElasticSearchRetrieveConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Threshold float64 `json:"threshold"`
Top int `json:"top"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Inputs []string `json:"inputs"`
ModelName string `json:"model_name"`
Top int `json:"top"`
}

type QueryPipelineConfig struct {
ID string `json:"id"`
Pipeline []interface{} `json:"pipeline"`
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
Type string `json:"type"`
Top int `json:"top"`
Skip int `json:"skip"`
MetadataFileters MetadataFilters `json:"metadata_fileters"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config"`
}

type RowLine struct {
Key string `json:"key"`
Index int `json:"index"`
Value string `json:"value"`
EnableIndexing bool `json:"enable_indexing"`
EnableResponse bool `json:"enable_response"`
}

type ChunkLocation struct {
PageNum []int `json:"paget_num"`
Box [][]int `json:"box"`
}

type Chunk struct {
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]interface{} `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations []ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
}

type QueryKnowledgeBaseResponse struct {
RequestId string `json:"requestId"`
Code string `json:"code"`
Message string `json:"message"`
Chunks []Chunk `json:"chunks"`
TotalCount int `json:"total_count"`
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public class AppBuilderConfig {
public static final String CHUNKS_DESCRIBE_URL = "/knowledgeBase?Action=DescribeChunks";
// 删除切片
public static final String CHUNK_DELETE_URL = "/knowledgeBase?Action=DeleteChunk";
// 知识库检索
public static final String QUERY_KNOWLEDGEBASE_URL = "/v2/knowledgebases/query";


// 运行rag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,4 +691,17 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I
ChunksDescribeResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBaseResponse(QueryKnowledgeBaseRequest request)
throws IOException, AppBuilderServerException {
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;

String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
postRequest.setHeader("Content-Type", "application/json");
HttpResponse<QueryKnowledgeBaseResponse> response = httpClient.execute(postRequest, QueryKnowledgeBaseResponse.class);
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}
}
Loading
Loading