Skip to content

Commit

Permalink
Provide Tensor MaxSim reranker for Fusion operator (#1244)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Provide Tensor MaxSim reranker for Fusion operator
Support Multiple fusion operators

Issue link:#1179

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
  • Loading branch information
yangzq50 authored May 24, 2024
1 parent 0d81b48 commit aef0348
Show file tree
Hide file tree
Showing 20 changed files with 639 additions and 152 deletions.
16 changes: 15 additions & 1 deletion src/executor/fragment_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,21 @@ void FragmentBuilder::BuildFragments(PhysicalOperator *phys_op, PlanFragment *cu
current_fragment_ptr->SetFragmentType(FragmentType::kSerialMaterialize);
break;
}
case PhysicalOperatorType::kFusion:
case PhysicalOperatorType::kFusion: {
if (phys_op->left() == nullptr) {
UnrecoverableError(fmt::format("No input node of {}", phys_op->GetName()));
}
if (phys_op->left()->operator_type() == PhysicalOperatorType::kFusion) {
if (phys_op->right() != nullptr) {
UnrecoverableError("Fusion operator with fusion operator child shouldn't have right child.");
}
current_fragment_ptr->AddOperator(phys_op);
// call next Fusion operator
BuildFragments(phys_op->left(), current_fragment_ptr);
break;
}
[[fallthrough]];
}
case PhysicalOperatorType::kMergeAggregate:
case PhysicalOperatorType::kMergeHash:
case PhysicalOperatorType::kMergeLimit:
Expand Down
347 changes: 324 additions & 23 deletions src/executor/operator/physical_fusion.cpp

Large diffs are not rendered by default.

30 changes: 24 additions & 6 deletions src/executor/operator/physical_fusion.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module;
export module physical_fusion;

import stl;

import base_table_ref;
import query_context;
import operator_state;
import physical_operator;
Expand All @@ -30,10 +30,12 @@ import internal_types;
import data_type;

namespace infinity {
struct DataBlock;

export class PhysicalFusion final: public PhysicalOperator {
export class PhysicalFusion final : public PhysicalOperator {
public:
explicit PhysicalFusion(u64 id,
SharedPtr<BaseTableRef> base_table_ref,
UniquePtr<PhysicalOperator> left,
UniquePtr<PhysicalOperator> right,
SharedPtr<FusionExpression> fusion_expr,
Expand All @@ -42,23 +44,39 @@ public:

void Init() override;

bool Execute(QueryContext *query_context, OperatorState *operator_state) final;
bool Execute(QueryContext *query_context, OperatorState *operator_state) override;

SharedPtr<Vector<String>> GetOutputNames() const final { return left_->GetOutputNames(); };
SharedPtr<Vector<String>> GetOutputNames() const override { return output_names_; }

SharedPtr<Vector<SharedPtr<DataType>>> GetOutputTypes() const final { return left_->GetOutputTypes(); };
SharedPtr<Vector<SharedPtr<DataType>>> GetOutputTypes() const override { return output_types_; }

SizeT TaskletCount() override {
UnrecoverableError("Not implement: TaskletCount not Implement");
return 0;
}

void FillingTableRefs(HashMap<SizeT, SharedPtr<BaseTableRef>> &table_refs) override {
table_refs.insert({base_table_ref_->table_index_, base_table_ref_});
}

String ToString(i64 &space) const;

SharedPtr<BaseTableRef> base_table_ref_{};
SharedPtr<FusionExpression> fusion_expr_;

private:

bool ExecuteFirstOp(QueryContext *query_context, FusionOperatorState *fusion_operator_state) const;
bool ExecuteNotFirstOp(QueryContext *query_context, OperatorState *operator_state) const;
// RRF has multiple input source, must be first op
void ExecuteRRF(const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks, Vector<UniquePtr<DataBlock>> &output_data_block_array) const;
// MatchTensor may have multiple or single input source, can be first or not first op
void ExecuteMatchTensor(QueryContext *query_context,
const Map<u64, Vector<UniquePtr<DataBlock>>> &input_data_blocks,
Vector<UniquePtr<DataBlock>> &output_data_block_array) const;

String to_lower_method_;
SharedPtr<Vector<String>> output_names_;
SharedPtr<Vector<SharedPtr<DataType>>> output_types_;
};

} // namespace infinity
9 changes: 4 additions & 5 deletions src/executor/physical_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,17 +902,16 @@ UniquePtr<PhysicalOperator> PhysicalPlanner::BuildMatchTensorScan(const SharedPt
}

UniquePtr<PhysicalOperator> PhysicalPlanner::BuildFusion(const SharedPtr<LogicalNode> &logical_operator) const {
SharedPtr<LogicalFusion> logical_fusion = static_pointer_cast<LogicalFusion>(logical_operator);
const auto logical_fusion = static_pointer_cast<LogicalFusion>(logical_operator);
UniquePtr<PhysicalOperator> left_phy = nullptr, right_phy = nullptr;
auto left_logical_node = logical_operator->left_node();
if (left_logical_node.get() != nullptr) {
if (const auto &left_logical_node = logical_operator->left_node(); left_logical_node.get() != nullptr) {
left_phy = BuildPhysicalOperator(left_logical_node);
}
auto right_logical_node = logical_operator->right_node();
if (right_logical_node.get() != nullptr) {
if (const auto right_logical_node = logical_operator->right_node(); right_logical_node.get() != nullptr) {
right_phy = BuildPhysicalOperator(right_logical_node);
}
return MakeUnique<PhysicalFusion>(logical_fusion->node_id(),
logical_fusion->base_table_ref_,
std::move(left_phy),
std::move(right_phy),
logical_fusion->fusion_expr_,
Expand Down
2 changes: 1 addition & 1 deletion src/expression/fusion_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import third_party;
namespace infinity {

FusionExpression::FusionExpression(const String &method, SharedPtr<SearchOptions> options)
: BaseExpression(ExpressionType::kFusion, Vector<SharedPtr<BaseExpression>>()), method_(method), options_(options) {}
: BaseExpression(ExpressionType::kFusion, Vector<SharedPtr<BaseExpression>>()), method_(method), options_(std::move(options)) {}

String FusionExpression::ToString() const {
if (!alias_.empty()) {
Expand Down
6 changes: 3 additions & 3 deletions src/expression/search_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ namespace infinity {
SearchExpression::SearchExpression(Vector<SharedPtr<MatchExpression>> &match_exprs,
Vector<SharedPtr<KnnExpression>> &knn_exprs,
Vector<SharedPtr<MatchTensorExpression>> &match_tensor_exprs,
SharedPtr<FusionExpression> fusion_expr)
Vector<SharedPtr<FusionExpression>> &fusion_exprs)
: BaseExpression(ExpressionType::kSearch, Vector<SharedPtr<BaseExpression>>()), match_exprs_(match_exprs), knn_exprs_(knn_exprs),
match_tensor_exprs_(match_tensor_exprs), fusion_expr_(fusion_expr) {}
match_tensor_exprs_(match_tensor_exprs), fusion_exprs_(fusion_exprs) {}

String SearchExpression::ToString() const {
if (!alias_.empty()) {
Expand Down Expand Up @@ -60,7 +60,7 @@ String SearchExpression::ToString() const {
cnt++;
oss << match_tensor_expr->ToString();
}
if (fusion_expr_.get() != nullptr) {
for (auto &fusion_expr_ : fusion_exprs_) {
if (cnt != 0)
oss << ", ";
cnt++;
Expand Down
4 changes: 2 additions & 2 deletions src/expression/search_expression.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public:
SearchExpression(Vector<SharedPtr<MatchExpression>> &match_exprs,
Vector<SharedPtr<KnnExpression>> &knn_exprs,
Vector<SharedPtr<MatchTensorExpression>> &match_tensor_exprs,
SharedPtr<FusionExpression> fusion_expr);
Vector<SharedPtr<FusionExpression>> &fusion_exprs);

inline DataType Type() const override { return DataType(LogicalType::kFloat); }

Expand All @@ -43,7 +43,7 @@ public:
Vector<SharedPtr<MatchExpression>> match_exprs_{};
Vector<SharedPtr<KnnExpression>> knn_exprs_{};
Vector<SharedPtr<MatchTensorExpression>> match_tensor_exprs_{};
SharedPtr<FusionExpression> fusion_expr_{};
Vector<SharedPtr<FusionExpression>> fusion_exprs_{};
};

} // namespace infinity
21 changes: 14 additions & 7 deletions src/parser/expr/search_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,22 @@ std::string SearchExpr::ToString() const {
oss << expr->ToString();
is_first = false;
}
if (fusion_expr_ != nullptr) {
oss << ", " << fusion_expr_->ToString();
for (auto &expr : fusion_exprs_) {
if (!is_first)
oss << ", ";
oss << expr->ToString();
is_first = false;
}
return oss.str();
}

void SearchExpr::SetExprs(std::vector<infinity::ParsedExpr *> *exprs) {
if (exprs == nullptr) {
ParserError("SearchExpr::SetExprs parameter is nullptr");
}
if (exprs_ != nullptr) {
ParserError("SearchExpr::SetExprs member exprs_ is not nullptr");
}
exprs_ = exprs;
for (ParsedExpr *expr : *exprs) {
AddExpr(expr);
Expand All @@ -65,8 +74,9 @@ void SearchExpr::Validate() const {
if (num_sub_expr <= 0) {
ParserError("Need at least one MATCH VECTOR / MATCH TENSOR / MATCH TEXT / QUERY expression");
} else if (num_sub_expr >= 2) {
if (fusion_expr_ == nullptr)
if (fusion_exprs_.empty()) {
ParserError("Need FUSION expr since there are multiple MATCH VECTOR / MATCH TENSOR / MATCH TEXT / QUERY expressions");
}
}
}

Expand All @@ -82,10 +92,7 @@ void SearchExpr::AddExpr(infinity::ParsedExpr *expr) {
match_tensor_exprs_.push_back(static_cast<MatchTensorExpr *>(expr));
break;
case ParsedExprType::kFusion:
if (fusion_expr_ != nullptr) {
ParserError("More than one FUSION expr");
}
fusion_expr_ = static_cast<FusionExpr *>(expr);
fusion_exprs_.push_back(static_cast<FusionExpr *>(expr));
break;
default:
ParserError("Invalid expr type for SEARCH");
Expand Down
2 changes: 1 addition & 1 deletion src/parser/expr/search_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SearchExpr : public ParsedExpr {
std::vector<MatchExpr *> match_exprs_{};
std::vector<KnnExpr *> knn_exprs_{};
std::vector<MatchTensorExpr *> match_tensor_exprs_{};
FusionExpr *fusion_expr_{};
std::vector<FusionExpr *> fusion_exprs_{};

private:
std::vector<infinity::ParsedExpr *> *exprs_{};
Expand Down
4 changes: 2 additions & 2 deletions src/planner/bind_context.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ public:
}
auto search_expr = (SearchExpr *)expr;

allow_distance = !search_expr->knn_exprs_.empty() && search_expr->fusion_expr_ == nullptr;
allow_score = !search_expr->match_exprs_.empty() || !search_expr->match_tensor_exprs_.empty() || search_expr->fusion_expr_ != nullptr;
allow_distance = !search_expr->knn_exprs_.empty() && search_expr->fusion_exprs_.empty();
allow_score = !search_expr->match_exprs_.empty() || !search_expr->match_tensor_exprs_.empty() || !(search_expr->fusion_exprs_.empty());
}

void AddSubqueryBinding(const String &name,
Expand Down
45 changes: 20 additions & 25 deletions src/planner/bound_select_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,65 +143,60 @@ SharedPtr<LogicalNode> BoundSelectStatement::BuildPlan(QueryContext *query_conte
return root;
} else {
SharedPtr<LogicalNode> root = nullptr;
SizeT num_children = search_expr_->match_exprs_.size() + search_expr_->knn_exprs_.size() + search_expr_->match_tensor_exprs_.size();
const SizeT num_children = search_expr_->match_exprs_.size() + search_expr_->knn_exprs_.size() + search_expr_->match_tensor_exprs_.size();
if (num_children <= 0) {
UnrecoverableError("SEARCH shall have at least one MATCH TEXT or MATCH VECTOR or MATCH TENSOR expression");
} else if (num_children >= 3) {
UnrecoverableError("SEARCH shall have at max two MATCH TEXT or MATCH VECTOR expression");
}

if (table_ref_ptr_->type() != TableRefType::kTable) {
UnrecoverableError("Not base table reference");
}
auto base_table_ref = static_pointer_cast<BaseTableRef>(table_ref_ptr_);
// FIXME: need check if there is subquery inside the where conditions
auto filter_expr = ComposeExpressionWithDelimiter(where_conditions_, ConjunctionType::kAnd);
auto common_query_filter =
MakeShared<CommonQueryFilter>(filter_expr, static_pointer_cast<BaseTableRef>(table_ref_ptr_), query_context->GetTxn()->BeginTS());
auto common_query_filter = MakeShared<CommonQueryFilter>(filter_expr, base_table_ref, query_context->GetTxn()->BeginTS());
Vector<SharedPtr<LogicalNode>> match_knn_nodes;
match_knn_nodes.reserve(search_expr_->match_exprs_.size());
match_knn_nodes.reserve(num_children);
for (auto &match_expr : search_expr_->match_exprs_) {
if (table_ref_ptr_->type() != TableRefType::kTable) {
UnrecoverableError("Not base table reference");
}
auto base_table_ref = static_pointer_cast<BaseTableRef>(table_ref_ptr_);
SharedPtr<LogicalMatch> matchNode = MakeShared<LogicalMatch>(bind_context->GetNewLogicalNodeId(), base_table_ref, match_expr);
matchNode->filter_expression_ = filter_expr;
matchNode->common_query_filter_ = common_query_filter;
match_knn_nodes.push_back(std::move(matchNode));
}
for (auto &match_tensor_expr : search_expr_->match_tensor_exprs_) {
if (table_ref_ptr_->type() != TableRefType::kTable) {
UnrecoverableError("Not base table reference");
}
auto base_table_ref = static_pointer_cast<BaseTableRef>(table_ref_ptr_);
auto match_tensor_node = MakeShared<LogicalMatchTensorScan>(bind_context->GetNewLogicalNodeId(), base_table_ref, match_tensor_expr);
match_tensor_node->filter_expression_ = filter_expr;
match_tensor_node->common_query_filter_ = common_query_filter;
match_tensor_node->InitExtraOptions();
match_knn_nodes.push_back(std::move(match_tensor_node));
}

bind_context->GenerateTableIndex();
for (auto &knn_expr : search_expr_->knn_exprs_) {
if (table_ref_ptr_->type() != TableRefType::kTable) {
UnrecoverableError("Not base table reference");
}
SharedPtr<LogicalKnnScan> knn_scan = BuildInitialKnnScan(table_ref_ptr_, knn_expr, query_context, bind_context);
knn_scan->filter_expression_ = filter_expr;
knn_scan->common_query_filter_ = common_query_filter;
match_knn_nodes.push_back(std::move(knn_scan));
}

if (search_expr_->fusion_expr_.get() != nullptr) {
SharedPtr<LogicalNode> fusionNode = MakeShared<LogicalFusion>(bind_context->GetNewLogicalNodeId(), search_expr_->fusion_expr_);
fusionNode->set_left_node(match_knn_nodes[0]);
if (!(search_expr_->fusion_exprs_.empty())) {
auto firstfusionNode = MakeShared<LogicalFusion>(bind_context->GetNewLogicalNodeId(), base_table_ref, search_expr_->fusion_exprs_[0]);
firstfusionNode->set_left_node(match_knn_nodes[0]);
if (match_knn_nodes.size() > 1)
fusionNode->set_right_node(match_knn_nodes[1]);
root = fusionNode;
firstfusionNode->set_right_node(match_knn_nodes[1]);
root = std::move(firstfusionNode);
// extra fusion nodes
for (u32 i = 1; i < search_expr_->fusion_exprs_.size(); ++i) {
auto extrafusionNode = MakeShared<LogicalFusion>(bind_context->GetNewLogicalNodeId(), base_table_ref, search_expr_->fusion_exprs_[i]);
extrafusionNode->set_left_node(root);
root = std::move(extrafusionNode);
}
} else {
root = match_knn_nodes[0];
root = std::move(match_knn_nodes[0]);
}

auto project = MakeShared<LogicalProject>(bind_context->GetNewLogicalNodeId(), projection_expressions_, projection_index_);
project->set_left_node(root);
root = project;
root = std::move(project);

return root;
}
Expand Down
9 changes: 5 additions & 4 deletions src/planner/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import between_expr;
import subquery_expr;
import match_expr;
import match_tensor_expr;
import fusion_expr;
import data_type;

import catalog;
Expand Down Expand Up @@ -834,7 +835,7 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildSearchExpr(const SearchExpr &ex
Vector<SharedPtr<MatchExpression>> match_exprs;
Vector<SharedPtr<KnnExpression>> knn_exprs;
Vector<SharedPtr<MatchTensorExpression>> match_tensor_exprs;
SharedPtr<FusionExpression> fusion_expr = nullptr;
Vector<SharedPtr<FusionExpression>> fusion_exprs;
for (MatchExpr *match_expr : expr.match_exprs_) {
match_exprs.push_back(MakeShared<MatchExpression>(match_expr->fields_, match_expr->matching_text_, match_expr->options_text_));
}
Expand All @@ -845,10 +846,10 @@ SharedPtr<BaseExpression> ExpressionBinder::BuildSearchExpr(const SearchExpr &ex
match_tensor_exprs.push_back(
static_pointer_cast<MatchTensorExpression>(BuildMatchTensorExpr(*match_tensor_expr, bind_context_ptr, depth, false)));
}
if (expr.fusion_expr_ != nullptr) {
fusion_expr = MakeShared<FusionExpression>(expr.fusion_expr_->method_, expr.fusion_expr_->options_);
for (FusionExpr *fusion_expr : expr.fusion_exprs_) {
fusion_exprs.push_back(MakeShared<FusionExpression>(fusion_expr->method_, fusion_expr->options_));
}
SharedPtr<SearchExpression> bound_search_expr = MakeShared<SearchExpression>(match_exprs, knn_exprs, match_tensor_exprs, fusion_expr);
SharedPtr<SearchExpression> bound_search_expr = MakeShared<SearchExpression>(match_exprs, knn_exprs, match_tensor_exprs, fusion_exprs);
return bound_search_expr;
}

Expand Down
5 changes: 2 additions & 3 deletions src/planner/node/logical_fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ import internal_types;

namespace infinity {

LogicalFusion::LogicalFusion(u64 node_id,
SharedPtr<FusionExpression> fusion_expr)
: LogicalNode(node_id, LogicalNodeType::kFusion), fusion_expr_(fusion_expr) {}
LogicalFusion::LogicalFusion(const u64 node_id, SharedPtr<BaseTableRef> base_table_ref, SharedPtr<FusionExpression> fusion_expr)
: LogicalNode(node_id, LogicalNodeType::kFusion), base_table_ref_(std::move(base_table_ref)), fusion_expr_(std::move(fusion_expr)) {}

String LogicalFusion::ToString(i64 &space) const {
std::stringstream ss;
Expand Down
4 changes: 2 additions & 2 deletions src/planner/node/logical_fusion.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ namespace infinity {

export class LogicalFusion : public LogicalNode {
public:
explicit LogicalFusion(u64 node_id,
SharedPtr<FusionExpression> fusion_expr);
explicit LogicalFusion(u64 node_id, SharedPtr<BaseTableRef> base_table_ref, SharedPtr<FusionExpression> fusion_expr);

Vector<ColumnBinding> GetColumnBindings() const final { return left_node_->GetColumnBindings(); };

Expand All @@ -45,6 +44,7 @@ public:

inline String name() final { return "LogicalFusion"; }

SharedPtr<BaseTableRef> base_table_ref_{};
SharedPtr<FusionExpression> fusion_expr_{};
};

Expand Down
Loading

0 comments on commit aef0348

Please sign in to comment.