Skip to content

Commit

Permalink
新增知识库检索
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj committed Dec 20, 2024
1 parent d55e6fc commit c8fa6ab
Show file tree
Hide file tree
Showing 8 changed files with 664 additions and 9 deletions.
36 changes: 36 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,39 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk

return rsp, nil
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/v2/knowledgebases/query")
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
request.URL = serviceURL
request.Method = "POST"
header.Set("Content-Type", "application/json")
request.Header = header
data, _ := json.Marshal(req)
request.Body = NopCloser(bytes.NewReader(data))
t.sdkConfig.BuildCurlCommand(&request)
resp, err := t.client.Do(&request)
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
defer resp.Body.Close()
requestID, err := checkHTTPResponse(resp)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}
data, err = io.ReadAll(resp.Body)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

rsp := QueryKnowledgeBaseResponse{}
if err := json.Unmarshal(data, &rsp); err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

return rsp, nil
}
87 changes: 87 additions & 0 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package appbuilder

import "time"

const (
ContentTypeRawText = "raw_text"
ContentTypeQA = "qa"
Expand Down Expand Up @@ -245,3 +247,88 @@ type DescribeChunksResponse struct {
NextMarker string `json:"nextMarker"`
MaxKeys int `json:"maxKeys"`
}

type MetadataFilter struct {
Operator string `json:"operator"`
Field string `json:"field"`
Value interface{} `json:"value"` // 因为Value的类型可以是str或list[str],所以我们使用interface{}来表示任何类型
}

type MetadataFilters struct {
Filters []MetadataFilter `json:"filters"`
Condition string `json:"condition"`
}

type PreRankingConfig struct {
Bm25Weight float64 `json:"bm25_weight"`
VecWeight float64 `json:"vec_weight"`
Bm25B float64 `json:"bm25_b"`
Bm25K1 float64 `json:"bm25_k1"`
Bm25MaxScore float64 `json:"bm25_max_score"`
}

type ElasticSearchRetrieveConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Threshold float64 `json:"threshold"`
Top int `json:"top"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Inputs []string `json:"inputs"`
ModelName string `json:"model_name"`
Top int `json:"top"`
}

type QueryPipelineConfig struct {
ID string `json:"id"`
Pipeline []interface{} `json:"pipeline"`
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
Type string `json:"type"`
Top int `json:"top"`
Skip int `json:"skip"`
MetadataFileters MetadataFilters `json:"metadata_fileters"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config"`
}

type RowLine struct {
Key string `json:"key"`
Index int `json:"index"`
Value string `json:"value"`
EnableIndexing bool `json:"enable_indexing"`
EnableResponse bool `json:"enable_response"`
}

type ChunkLocation struct {
PageNum []int `json:"paget_num"`
Box [][]int `json:"box"`
}

type Chunk struct {
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]interface{} `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime time.Time `json:"create_time"`
UpdateTime time.Time `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations []ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
}

type QueryKnowledgeBaseResponse struct {
RequestId string `json:"requestId"`
Code string `json:"code"`
Message string `json:"message"`
Chunks []Chunk `json:"chunks"`
TotalCount int `json:"total_count"`
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public class AppBuilderConfig {
public static final String CHUNKS_DESCRIBE_URL = "/knowledgeBase?Action=DescribeChunks";
// 删除切片
public static final String CHUNK_DELETE_URL = "/knowledgeBase?Action=DeleteChunk";
// 知识库检索
public static final String QUERY_KNOWLEDGEBASE_URL = "/v2/knowledgebases/query";


// 运行rag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,4 +691,17 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I
ChunksDescribeResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBaseResponse(QueryKnowledgeBaseRequest request)
throws IOException, AppBuilderServerException {
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;

String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
postRequest.setHeader("Content-Type", "application/json");
HttpResponse<QueryKnowledgeBaseResponse> response = httpClient.execute(postRequest, QueryKnowledgeBaseResponse.class);
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}
}
Loading

0 comments on commit c8fa6ab

Please sign in to comment.