Skip to content

Commit

Permalink
Enable multiple fusion in HTTP API (#1642)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
Change the implementation of select in the HTTP API to enable multiple fusion.

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring
  • Loading branch information
Ami11111 authored Aug 13, 2024
1 parent d0308d5 commit 24ec1eb
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 123 deletions.
39 changes: 12 additions & 27 deletions python/test_pysdk/http_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def output(
self.output_res = []
self._output = output
self._filter = ""
self._fusion = {}
self._fusion = []
self._knn = {}
self._match = {}
self._match_tensor = {}
Expand Down Expand Up @@ -554,40 +554,25 @@ def knn(self, fields, query_vector, element_type, metric_type, top_k):
return self

def fusion(self, method="", option="", optional_match_tensor: CommonMatchTensorExpr = None):
if len(self._fusion):
tmp = self._fusion
self._fusion = {}
self._fusion["fusion"] = tmp

self._fusion["method"] = method
_fusion = {}
_fusion["method"] = method
if len(option) :
self._fusion["option"] = option
_fusion["option"] = option
if method == "match_tensor":
self._fusion["optional_match_tensor"] = {}
_fusion["optional_match_tensor"] = {}
vector_column_name = optional_match_tensor.vector_column_name
if isinstance(vector_column_name, list):
pass
else:
vector_column_name = [vector_column_name]
self._fusion["optional_match_tensor"]["fields"] = vector_column_name
self._fusion["optional_match_tensor"]["query_tensor"] = optional_match_tensor.embedding_data
_fusion["optional_match_tensor"]["fields"] = vector_column_name
_fusion["optional_match_tensor"]["query_tensor"] = optional_match_tensor.embedding_data
if optional_match_tensor.extra_option:
self._fusion["optional_match_tensor"]["options"] = optional_match_tensor.extra_option
self._fusion["optional_match_tensor"]["element_type"] = type_transfrom[optional_match_tensor.embedding_data_type]
self._fusion["optional_match_tensor"]["search_method"] = optional_match_tensor.method_type
_fusion["optional_match_tensor"]["options"] = optional_match_tensor.extra_option
_fusion["optional_match_tensor"]["element_type"] = type_transfrom[optional_match_tensor.embedding_data_type]
_fusion["optional_match_tensor"]["search_method"] = optional_match_tensor.method_type

if len(self._knn):
self._fusion["knn"] = self._knn
self._knn = {}
if len(self._match):
self._fusion["match"] = self._match
self._match = {}
if len(self._match_tensor):
self._fusion["match_tensor"] = self._match_tensor
self._match_tensor = {}
if len(self._match_sparse):
self._fusion["match_sparse"] = self._match_sparse
self._match_sparse = {}
self._fusion.append(_fusion)
return self

def to_result(self):
Expand Down Expand Up @@ -680,7 +665,7 @@ def update(self, filter="", update={},):

class database_result(http_adapter):
def __init__(self, list = [], error_code = ErrorCode.OK, database_name = "" ,columns=[], table_name = "",
index_list = [], output = ["*"], filter="", fusion={}, knn={}, match = {}, match_tensor = {}, match_sparse = {}, output_res = []):
index_list = [], output = ["*"], filter="", fusion=[], knn={}, match = {}, match_tensor = {}, match_sparse = {}, output_res = []):
self.db_names = list
self.error_code = error_code
self.database_name = database_name # get database
Expand Down
1 change: 0 additions & 1 deletion python/test_pysdk/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,6 @@ def test_sparse_knn_with_index(self, check_data):
res = db_obj.drop_table("test_sparse_knn_with_index", ConflictType.Error)
assert res.error_code == ErrorCode.OK

@pytest.mark.usefixtures("skip_if_http")
@pytest.mark.parametrize("check_data", [{"file_name": "tensor_maxsim.csv",
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
def test_with_multiple_fusion(self, check_data):
Expand Down
138 changes: 44 additions & 94 deletions src/network/http/http_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
response["error_message"] = "HTTP Body isn't json object";
}

SizeT match_expr_count = 0;
SizeT fusion_expr_count = 0;
Vector<ParsedExpr *> *output_columns{nullptr};
Vector<ParsedExpr *> *search_exprs{nullptr};
ParsedExpr *filter{nullptr};
Expand Down Expand Up @@ -127,81 +129,80 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
return;
}
} else if (IsEqual(key, "fusion")) {
if (search_expr != nullptr) {
if (search_expr == nullptr) {
search_expr = new SearchExpr();
}
auto &fusion_json_list = elem.value();
if (fusion_json_list.type() != nlohmann::json::value_t::array) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
response["error_message"] = "Fusion field should be list";
return;
}
search_expr = new SearchExpr();
auto &fusion_json = elem.value();
if (!ParseFusion(*search_exprs, fusion_json, http_status, response)) {
// error
return;
for (auto &fusion_json : fusion_json_list) {
const auto fusion_expr = ParseFusion(fusion_json, http_status, response);
if (fusion_expr == nullptr) {
return;
}
search_exprs->push_back(fusion_expr);
++fusion_expr_count;
}
} else if (IsEqual(key, "knn")) {
if (search_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
return;
if (search_expr == nullptr) {
search_expr = new SearchExpr();
}
search_expr = new SearchExpr();
auto &knn_json = elem.value();
const auto knn_expr = ParseKnn(knn_json, http_status, response);
if (knn_expr == nullptr) {
return;
}
search_exprs->push_back(knn_expr);
++match_expr_count;
} else if (IsEqual(key, "match")) {
if (search_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
return;
if (search_expr == nullptr) {
search_expr = new SearchExpr();
}
search_expr = new SearchExpr();
auto &match_json = elem.value();
const auto match_expr = ParseMatch(match_json, http_status, response);
if (match_expr == nullptr) {
return;
}
search_exprs->push_back(match_expr);
++match_expr_count;
} else if (IsEqual(key, "match_tensor")) {
if (search_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
return;
if (search_expr == nullptr) {
search_expr = new SearchExpr();
}
search_expr = new SearchExpr();
auto &match_json = elem.value();
const auto match_expr = ParseMatchTensor(match_json, http_status, response);
if (match_expr == nullptr) {
return;
}
search_exprs->push_back(match_expr);
++match_expr_count;
} else if (IsEqual(key, "match_sparse")) {
if (search_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
return;
if (search_expr == nullptr) {
search_expr = new SearchExpr();
}
search_expr = new SearchExpr();
auto &match_json = elem.value();
const auto match_expr = ParseMatchSparse(match_json, http_status, response);
if (match_expr == nullptr) {
return;
}
search_exprs->push_back(match_expr);
++match_expr_count;
} else {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Unknown expression: " + key;
return;
}
}

if (match_expr_count > 1 && fusion_expr_count == 0) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "More than one knn or match experssion with no fusion experssion!";
return;
}

if (search_exprs != nullptr && !search_exprs->empty()) {
search_expr->SetExprs(search_exprs);
search_exprs = nullptr;
Expand Down Expand Up @@ -314,71 +315,26 @@ Vector<ParsedExpr *> *HTTPSearch::ParseOutput(const nlohmann::json &output_list,
output_columns = nullptr;
return res;
}
bool HTTPSearch::ParseFusion(Vector<ParsedExpr *> &search_exprs,
const nlohmann::json &fusion_json_object,
FusionExpr* HTTPSearch::ParseFusion(const nlohmann::json &fusion_json_object,
HTTPStatus &http_status,
nlohmann::json &response) {
if (!fusion_json_object.is_object()) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = fmt::format("Fusion expression must be a json object: {}", fusion_json_object);
return false;
return nullptr;
}
// case 1. child fusion
// case 2. knn, match, match tensor

u32 method_cnt = 0;
u32 child_fusion_cnt = 0;
u32 child_match_cnt = 0;

UniquePtr<FusionExpr> fusion_expr = nullptr;
for (const auto &expression : fusion_json_object.items()) {
String key = expression.key();
ToLower(key);
if (IsEqual(key, "fusion")) {
if (child_fusion_cnt++) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Fusion child is already given";
return false;
}
auto &child_fusion_json = expression.value();
if (!ParseFusion(search_exprs, child_fusion_json, http_status, response)) {
return false;
}
} else if (IsEqual(key, "knn")) {
++child_match_cnt;
auto &knn_json = expression.value();
const auto knn_expr = ParseKnn(knn_json, http_status, response);
if (knn_expr == nullptr) {
return false;
}
search_exprs.push_back(knn_expr);
} else if (IsEqual(key, "match")) {
++child_match_cnt;
auto &match_json = expression.value();
const auto match_expr = ParseMatch(match_json, http_status, response);
if (match_expr == nullptr) {
return false;
}
search_exprs.push_back(match_expr);
} else if (IsEqual(key, "match_tensor")) {
++child_match_cnt;
auto &match_tensor_json = expression.value();
const auto match_tensor_expr = ParseMatchTensor(match_tensor_json, http_status, response);
if (match_tensor_expr == nullptr) {
return false;
}
search_exprs.push_back(match_tensor_expr);
} else if (IsEqual(key, "match_sparse")) {
++child_match_cnt;
auto &match_sparse_json = expression.value();
const auto match_tensor_expr = ParseMatchSparse(match_sparse_json, http_status, response);
if (match_tensor_expr == nullptr) {
return false;
}
search_exprs.push_back(match_tensor_expr);
} else if (IsEqual(key, "method")) {
if (IsEqual(key, "method")) {
if (method_cnt++) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Method is already given";
return false;
return nullptr;
}
if (!fusion_expr) {
fusion_expr = MakeUnique<FusionExpr>();
Expand All @@ -396,30 +352,24 @@ bool HTTPSearch::ParseFusion(Vector<ParsedExpr *> &search_exprs,
auto &match_tensor_json = expression.value();
const auto match_tensor_expr = ParseMatchTensor(match_tensor_json, http_status, response);
if (match_tensor_expr == nullptr) {
return false;
return nullptr;
}
fusion_expr->match_tensor_expr_.reset(match_tensor_expr);
} else {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Error fusion clause";
return false;
return nullptr;
}
}
if ((child_fusion_cnt == 0 and child_match_cnt == 0) or (child_fusion_cnt > 0 and child_match_cnt > 0)) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Error fusion clause";
return false;
}

if (fusion_expr) {
if (fusion_expr->method_.empty()) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Error fusion clause : empty method";
return false;
return nullptr;
}
//fusion_expr->JobAfterParser();
search_exprs.push_back(fusion_expr.release());
}
return true;
return fusion_expr.release();
}
KnnExpr *HTTPSearch::ParseKnn(const nlohmann::json &knn_json_object, HTTPStatus &http_status, nlohmann::json &response) {
if (!knn_json_object.is_object()) {
Expand Down
3 changes: 2 additions & 1 deletion src/network/http/http_search.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import third_party;
import parsed_expr;
import knn_expr;
import match_expr;
import fusion_expr;
import match_tensor_expr;
import match_sparse_expr;
import infinity;
Expand All @@ -40,7 +41,7 @@ public:

static ParsedExpr *ParseFilter(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static Vector<ParsedExpr *> *ParseOutput(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static bool ParseFusion(Vector<ParsedExpr *> &search_exprs, const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static FusionExpr *ParseFusion(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static KnnExpr *ParseKnn(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static MatchExpr *ParseMatch(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
static MatchTensorExpr *ParseMatchTensor(const nlohmann::json &json_object, HTTPStatus &http_status, nlohmann::json &response);
Expand Down

0 comments on commit 24ec1eb

Please sign in to comment.