Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

新增知识库检索 #679

Merged
merged 1 commit into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,39 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk

return rsp, nil
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/knowledgebases/query")
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
request.URL = serviceURL
request.Method = "POST"
header.Set("Content-Type", "application/json")
request.Header = header
data, _ := json.Marshal(req)
request.Body = NopCloser(bytes.NewReader(data))
t.sdkConfig.BuildCurlCommand(&request)
resp, err := t.client.Do(&request)
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
defer resp.Body.Close()
requestID, err := checkHTTPResponse(resp)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}
data, err = io.ReadAll(resp.Body)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

rsp := QueryKnowledgeBaseResponse{}
if err := json.Unmarshal(data, &rsp); err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

return rsp, nil
}
86 changes: 86 additions & 0 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,89 @@ type DescribeChunksResponse struct {
NextMarker string `json:"nextMarker"`
MaxKeys int `json:"maxKeys"`
}

type MetadataFilter struct {
Operator string `json:"operator"`
Field string `json:"field,omitempty"`
Value any `json:"value"`
}

type MetadataFilters struct {
Filters []MetadataFilter `json:"filters"`
Condition string `json:"condition"`
}

type PreRankingConfig struct {
Bm25Weight float64 `json:"bm25_weight"`
VecWeight float64 `json:"vec_weight"`
Bm25B float64 `json:"bm25_b"`
Bm25K1 float64 `json:"bm25_k1"`
Bm25MaxScore float64 `json:"bm25_max_score"`
}

type ElasticSearchRetrieveConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Threshold float64 `json:"threshold"`
Top int `json:"top"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Inputs []string `json:"inputs"`
ModelName string `json:"model_name"`
Top int `json:"top"`
}

type QueryPipelineConfig struct {
ID string `json:"id"`
Pipeline []any `json:"pipeline"`
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
Type *string `json:"type,omitempty"`
Top int `json:"top,omitempty"`
Skip int `json:"skip,omitempty"`
MetadataFileters MetadataFilters `json:"metadata_fileters,omitempty"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config,omitempty"`
}

type RowLine struct {
Key string `json:"key"`
Index int `json:"index"`
Value string `json:"value"`
EnableIndexing bool `json:"enable_indexing"`
EnableResponse bool `json:"enable_response"`
}

type ChunkLocation struct {
PageNum []int `json:"paget_num"`
Box [][]int `json:"box"`
}

type Chunk struct {
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]any `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
}

type QueryKnowledgeBaseResponse struct {
RequestId string `json:"requestId"`
Code string `json:"code"`
Message string `json:"message"`
Chunks []Chunk `json:"chunks"`
TotalCount int `json:"total_count"`
}
92 changes: 92 additions & 0 deletions go/appbuilder/knowledge_base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package appbuilder

import (
"bytes"
"encoding/json"
"fmt"
"os"
"strings"
Expand Down Expand Up @@ -1478,3 +1479,94 @@ func TestChunk(t *testing.T) {
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}

func TestQueryKnowledgeBase(t *testing.T) {
t.Parallel() // 并发运行
var logBuffer bytes.Buffer

os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG")

log := func(format string, args ...any) {
fmt.Fprintf(&logBuffer, format+"\n", args...)
}

config, err := NewSDKConfig("", os.Getenv(SecretKey))
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new http client config failed: %v", err)
}

client, err := NewKnowledgeBase(config)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new Knowledge base instance failed")
}

jsonStr := `
{
"type": "fulltext",
"query": "民法典第三条",
"knowledgebase_ids": [
"70c6375a-1595-41f2-9a3b-e81bc9060b7f"
],
"metadata_filters": {
"filters": [
],
"condition": "or"
},
"pipeline_config": {
"id": "pipeline_001",
"pipeline": [
{
"name": "step1",
"type": "elastic_search",
"threshold": 0.1,
"top": 400,
"pre_ranking": {
"bm25_weight": 0.25,
"vec_weight": 0.75,
"bm25_b": 0.75,
"bm25_k1": 1.5,
"bm25_max_score": 50
}
},
{
"name": "step2",
"type": "ranking",
"inputs": ["step1"],
"model_name": "ranker-v1",
"top": 20
}
]
},
"top": 5,
"skip": 0
}`

var request QueryKnowledgeBaseRequest
err = json.Unmarshal([]byte(jsonStr), &request)
if err != nil {
fmt.Println("unmarshal tool error:", err)
}

queryKnowledgeBaseResponse, err := client.QueryKnowledgeBase(request)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("create chunk failed: %v", err)
}
chunkID := queryKnowledgeBaseResponse.Chunks[0].ChunkID
log("query got chunk ID: %s", chunkID)
if len(chunkID) == 0 {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("query knowledge base failed: %v", err)
}

// 如果测试失败,则输出缓冲区中的日志
if t.Failed() {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
fmt.Println(logBuffer.String())
} else { // else 紧跟在右大括号后面
// 测试通过,打印文件名和测试函数名
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public class AppBuilderConfig {
public static final String CHUNKS_DESCRIBE_URL = "/knowledgeBase?Action=DescribeChunks";
// 删除切片
public static final String CHUNK_DELETE_URL = "/knowledgeBase?Action=DeleteChunk";
// 知识库检索
public static final String QUERY_KNOWLEDGEBASE_URL = "/knowledgebases/query";


// 运行rag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,4 +691,34 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I
ChunksDescribeResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBase(QueryKnowledgeBaseRequest request)
throws IOException, AppBuilderServerException {
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;

String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
postRequest.setHeader("Content-Type", "application/json");
HttpResponse<QueryKnowledgeBaseResponse> response = httpClient.execute(postRequest,
QueryKnowledgeBaseResponse.class);
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBase(String query, String type, Integer top, Integer skip,
String[] knowledgebaseIDs, QueryKnowledgeBaseRequest.MetadataFilters filters,
QueryKnowledgeBaseRequest.QueryPipelineConfig pipelineConfig)
throws IOException, AppBuilderServerException {
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;
QueryKnowledgeBaseRequest request = new QueryKnowledgeBaseRequest(query, type, top, skip, knowledgebaseIDs, filters, pipelineConfig);
String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
postRequest.setHeader("Content-Type", "application/json");
HttpResponse<QueryKnowledgeBaseResponse> response = httpClient.execute(postRequest,
QueryKnowledgeBaseResponse.class);
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}
}
Loading
Loading