Skip to content

Commit

Permalink
add ReplaceNodeScanWithIndexSeek
Browse files Browse the repository at this point in the history
  • Loading branch information
ljcui committed Dec 1, 2024
1 parent 64ec55a commit 06410d1
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/cypher/execution_plan/ops/op_node_by_label_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class NodeByLabelScan : public OpBase {

const std::string& GetLabel() { return label_; }

const SymbolTable * GetSymtab() {return sym_tab_;}

CYPHER_DEFINE_VISITABLE()

CYPHER_DEFINE_CONST_VISITABLE()
Expand Down
4 changes: 3 additions & 1 deletion src/cypher/execution_plan/optimization/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "execution_plan/optimization/locate_node_by_indexed_prop_v2.h"
#include "execution_plan/optimization/locate_node_by_prop_range_filter.h"
#include "execution_plan/optimization/parallel_traversal_v2.h"
#include "execution_plan/optimization/rewrite_label_scan.h"

namespace cypher {

Expand All @@ -48,7 +49,8 @@ class PassManager {
all_passes_.emplace_back(new ParallelTraversal());
all_passes_.emplace_back(new ParallelTraversalV2());
all_passes_.emplace_back(new LocateNodeByVidV2());
all_passes_.emplace_back(new LocateNodeByIndexedPropV2());
//all_passes_.emplace_back(new LocateNodeByIndexedPropV2());
all_passes_.emplace_back(new ReplaceNodeScanWithIndexSeek(ctx));
all_passes_.emplace_back(new LocateNodeByPropRangeFilter());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class PropertyFilterDetector : public cypher::OptimizationFilterVisitorImpl {
std::string cur_field_;
std::set<lgraph::FieldData> cur_properties_;

std::any visit(geax::frontend::BOr* node) override {
std::any visit(geax::frontend::BAnd* node) override {
ACCEPT_AND_CHECK_WITH_PASS_MSG(node->left());
ACCEPT_AND_CHECK_WITH_PASS_MSG(node->right());
return geax::frontend::GEAXErrorCode::GEAX_OPTIMIZATION_PASS;
Expand All @@ -69,6 +69,7 @@ class PropertyFilterDetector : public cypher::OptimizationFilterVisitorImpl {
}

std::any visit(geax::frontend::BEqual* node) override {
cur_properties_.clear();
ACCEPT_AND_CHECK_WITH_PASS_MSG(node->left());
ACCEPT_AND_CHECK_WITH_PASS_MSG(node->right());
if (!cur_properties_.empty()) {
Expand All @@ -87,6 +88,7 @@ class PropertyFilterDetector : public cypher::OptimizationFilterVisitorImpl {
}

std::any visit(geax::frontend::BIn* node) override {
cur_properties_.clear();
ACCEPT_AND_CHECK_WITH_PASS_MSG(node->left());
ACCEPT_AND_CHECK_WITH_PASS_MSG(node->right());
if (!cur_properties_.empty()) {
Expand Down
145 changes: 145 additions & 0 deletions src/cypher/execution_plan/optimization/rewrite_label_scan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/**
* Copyright 2022 AntGroup CO., Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "tools/lgraph_log.h"
#include "core/data_type.h"
#include "cypher/execution_plan/ops/op_filter.h"
#include "cypher/execution_plan/ops/op_node_index_seek.h"
#include "cypher/execution_plan/ops/op_node_by_label_scan.h"
#include "cypher/execution_plan/optimization/opt_pass.h"
#include "cypher/execution_plan/optimization/property_filter_detector.h"

namespace cypher {

typedef std::unordered_map<
std::string, std::unordered_map<
std::string,std::set<lgraph::FieldData>>> FilterCollections;

class ReplaceNodeScanWithIndexSeek : public OptPass {
private:
RTContext *ctx_ = nullptr;
const lgraph::SchemaInfo *si_ = nullptr;

void Impl(OpBase *root) {
OpBase *op_filter = nullptr;
FilterCollections filters;
if (FindNodePropFilter(root, op_filter, filters)) {
Replace(op_filter, filters);
}
}

bool FindNodePropFilter(OpBase *root, OpBase *&op_filter, FilterCollections &filters) {
auto op = root;
if (op->type == OpType::FILTER ) {
auto filter = dynamic_cast<OpFilter *>(op);
if (_CheckPropFilter(filter, filters)) {
op_filter = op;
return true;
}
}

for (auto child : op->children) {
if (FindNodePropFilter(child, op_filter, filters)) return true;
}

return false;
}

void Replace(OpBase *root, FilterCollections &filter_collections) {
if (root->type == OpType::NODE_BY_LABEL_SCAN) {
auto scan = dynamic_cast<NodeByLabelScan *>(root);
auto label = scan->GetLabel();
auto node = scan->GetNode();
auto n = node->Alias();
if (!filter_collections.count(n)) {
return;
}
auto& filters = filter_collections.at(n);
auto schema = si_->v_schema_manager.GetSchema(label);
if (!schema) {
return;
}
auto pk = schema->GetPrimaryField();
if (filters.count(pk)) {
std::vector<lgraph::FieldData> values;
for (auto& val : filters.at(pk)) {
values.push_back(val);
}
auto parent = root->parent;
auto op_node_index_seek = new NodeIndexSeek(node, scan->GetSymtab(), pk, values);
op_node_index_seek->parent = parent;
parent->RemoveChild(root);
OpBase::FreeStream(root);
parent->AddChild(op_node_index_seek);
return;
}
for (auto& [k, set] : filters) {
if (k == pk) {
continue;
}
if (!schema->TryGetFieldExtractor(k)->GetVertexIndex()) {
continue;
}
std::vector<lgraph::FieldData> values;
for (auto& val : set) {
values.push_back(val);
}
auto parent = root->parent;
auto op_node_index_seek = new NodeIndexSeek(node, scan->GetSymtab(), k, values);
op_node_index_seek->parent = parent;
parent->RemoveChild(root);
OpBase::FreeStream(root);
parent->AddChild(op_node_index_seek);
return;
}
return;
}
for (auto child : root->children) {
Replace(child, filter_collections);
}
}

bool _CheckPropFilter(OpFilter *&op_filter, FilterCollections &filters) {
auto filter = op_filter->Filter();
CYPHER_THROW_ASSERT(filter->Type() == lgraph::Filter::GEAX_EXPR_FILTER);
auto geax_filter = ((lgraph::GeaxExprFilter *)filter.get())->GetArithExpr();
geax::frontend::Expr *expr = geax_filter.expr_;
PropertyFilterDetector detector;
if (!detector.Build(expr)) return false;
filters = detector.GetProperties();
if (filters.empty()) return false;
return true;
}
public:
ReplaceNodeScanWithIndexSeek(RTContext *ctx)
: OptPass(typeid(ReplaceNodeScanWithIndexSeek).name()), ctx_(ctx) {}
bool Gate() override { return true; }
int Execute(OpBase *root) override {
if (ctx_->graph_.empty()) {
return 0;
}
ctx_->ac_db_ = std::make_unique<lgraph::AccessControlledDB>(
ctx_->galaxy_->OpenGraph(ctx_->user_, ctx_->graph_));
lgraph_api::GraphDB db(ctx_->ac_db_.get(), true);
auto txn = db.CreateReadTxn();
si_ = &txn.GetTxn()->GetSchemaInfo();
Impl(root);
txn.Abort();
return 0;
}
};

}
9 changes: 9 additions & 0 deletions src/cypher/parser/cypher_base_visitor_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ void CypherBaseVisitorV2::PropertyExtractor(geax::frontend::ElementFiller *fille
}
}
for (auto &label : labels) {
if (isVertex) {
if (!node_property_.count(label)) {
node_property_[label] = {};
}
} else {
if (!rel_property_.count(label)) {
rel_property_[label] = {};
}
}
for (auto field : fields) {
if (isVertex) {
node_property_[label].emplace(field);
Expand Down

0 comments on commit 06410d1

Please sign in to comment.