diff --git a/src/cypher/execution_plan/ops/op_expand_all.h b/src/cypher/execution_plan/ops/op_expand_all.h index 04ca410de9..2b2ed20af1 100644 --- a/src/cypher/execution_plan/ops/op_expand_all.h +++ b/src/cypher/execution_plan/ops/op_expand_all.h @@ -60,6 +60,10 @@ class ExpandAll : public OpBase { if (neighbor_->Label().empty()) return true; auto nbr_it = ctx->txn_->GetTxn()->GetVertexIterator(eit_->GetNbr(expand_direction_)); while (ctx->txn_->GetTxn()->GetVertexLabel(nbr_it) != neighbor_->Label()) { + if (ctx->per_node_limit_.has_value() && expand_count_ > ctx->per_node_limit_.value()) { + return false; + } + expand_count_ += 1; eit_->Next(); if (!eit_->IsValid()) return false; nbr_it.Goto(eit_->GetNbr(expand_direction_)); @@ -81,9 +85,12 @@ class ExpandAll : public OpBase { if (state_ == ExpandAllResetted) { /* Start node iterator may be invalid, such as when the start is an argument * produced by OPTIONAL MATCH. */ + expand_count_ = 1; if (start_->PullVid() < 0) return OP_REFRESH; _InitializeEdgeIter(ctx); - while (_CheckToSkipEdge(ctx)) { + while (_CheckToSkipEdge(ctx) && (!ctx->per_node_limit_.has_value() || + expand_count_ <= ctx->per_node_limit_.value())) { + expand_count_ += 1; eit_->Next(); } if (!eit_->IsValid() || !_FilterNeighborLabel(ctx)) return OP_REFRESH; @@ -98,9 +105,17 @@ class ExpandAll : public OpBase { // The iterators are set, keep on consuming. pattern_graph_->VisitedEdges().Erase(*eit_); do { + if (ctx->per_node_limit_.has_value() && expand_count_ > ctx->per_node_limit_.value()) { + break; + } + expand_count_ += 1; eit_->Next(); } while (_CheckToSkipEdge(ctx)); - if (!eit_->IsValid() || !_FilterNeighborLabel(ctx)) return OP_REFRESH; + if ((ctx->per_node_limit_.has_value() && expand_count_ > ctx->per_node_limit_.value()) || + !eit_->IsValid() || !_FilterNeighborLabel(ctx)) { + neighbor_->PushVid(-1); + return OP_REFRESH; + } neighbor_->PushVid(eit_->GetNbr(expand_direction_)); pattern_graph_->VisitedEdges().Add(*eit_); _DumpForDebug(); @@ -119,6 +134,7 @@ class ExpandAll : public OpBase { bool expand_into_; ExpandTowards expand_direction_; std::shared_ptr edge_filter_ = nullptr; + size_t expand_count_; /* ExpandAllStates * Different states in which ExpandAll can be at. */ diff --git a/src/cypher/execution_plan/ops/op_var_len_expand.h b/src/cypher/execution_plan/ops/op_var_len_expand.h index c1a4ffd8c4..b846285c8f 100644 --- a/src/cypher/execution_plan/ops/op_var_len_expand.h +++ b/src/cypher/execution_plan/ops/op_var_len_expand.h @@ -34,7 +34,7 @@ namespace cypher { /* Variable Length Expand */ class VarLenExpand : public OpBase { - void _InitializeEdgeIter(RTContext *ctx, int64_t vid, lgraph::EIter &eit) { + void _InitializeEdgeIter(RTContext *ctx, int64_t vid, lgraph::EIter &eit, size_t &count) { auto &types = relp_->Types(); auto iter_type = lgraph::EIter::NA; switch (expand_direction_) { @@ -49,6 +49,7 @@ class VarLenExpand : public OpBase { break; } eit.Initialize(ctx->txn_->GetTxn().get(), iter_type, vid, types); + count = 1; } #if 0 // 20210704 @@ -177,13 +178,20 @@ class VarLenExpand : public OpBase { } #endif + bool PerNodeLimit(RTContext *ctx, size_t k) { + return !ctx->per_node_limit_.has_value() || + expand_counts_[k] <= ctx->per_node_limit_.value(); + } + int64_t GetFirstFromKthHop(RTContext *ctx, size_t k) { auto start_id = start_->PullVid(); relp_->path_.Clear(); relp_->path_.SetStart(start_id); if (k == 0) return start_id; - _InitializeEdgeIter(ctx, start_id, eits_[0]); - if (!eits_[0].IsValid()) return -1; + _InitializeEdgeIter(ctx, start_id, eits_[0], expand_counts_[0]); + if (!eits_[0].IsValid() || !PerNodeLimit(ctx, 0)) { + return -1; + } if (k == 1) { relp_->path_.Append(eits_[0].GetUid()); if (ctx->path_unique_) pattern_graph_->VisitedEdges().Add(eits_[0]); @@ -208,24 +216,28 @@ class VarLenExpand : public OpBase { if (!get_first || k != 1 || (ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(eits_[k - 1]))) { do { + expand_counts_[k - 1] += 1; eits_[k - 1].Next(); - } while (eits_[k - 1].IsValid() && ctx->path_unique_ && + } while (eits_[k - 1].IsValid() && PerNodeLimit(ctx, k - 1) && ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(eits_[k - 1])); } do { - if (!eits_[k - 1].IsValid()) { + if (!eits_[k - 1].IsValid() || !PerNodeLimit(ctx, k - 1)) { auto id = GetNextFromKthHop(ctx, k - 1, get_first); if (id < 0) return id; - _InitializeEdgeIter(ctx, id, eits_[k - 1]); + _InitializeEdgeIter(ctx, id, eits_[k - 1], expand_counts_[k - 1]); /* We have called get_next previously, mark get_first as * false. */ get_first = false; } while (ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(eits_[k - 1])) { + expand_counts_[k - 1] += 1; eits_[k - 1].Next(); } - } while (!eits_[k - 1].IsValid()); - if (!eits_[k - 1].IsValid()) return -1; + } while (!eits_[k - 1].IsValid() || !PerNodeLimit(ctx, k - 1)); + if (!eits_[k - 1].IsValid() || !PerNodeLimit(ctx, k - 1)) { + return -1; + } relp_->path_.Append(eits_[k - 1].GetUid()); if (ctx->path_unique_) pattern_graph_->VisitedEdges().Add(eits_[k - 1]); return eits_[k - 1].GetNbr(expand_direction_); @@ -258,19 +270,20 @@ class VarLenExpand : public OpBase { auto vid = GetFirstFromKthHop(ctx, hop_ - 1); if (vid < 0) return OP_REFRESH; if (hop_ > 1 && !eits_[hop_ - 2].IsValid()) CYPHER_INTL_ERR(); - _InitializeEdgeIter(ctx, vid, eits_[hop_ - 1]); + _InitializeEdgeIter(ctx, vid, eits_[hop_ - 1], expand_counts_[hop_ - 1]); // TODO(anyone) merge these code similiar to GetNextFromKthHop do { - if (!eits_[hop_ - 1].IsValid()) { + if (!eits_[hop_ - 1].IsValid() || !PerNodeLimit(ctx, hop_ - 1)) { auto v = GetNextFromKthHop(ctx, hop_ - 1, false); if (v < 0) return OP_REFRESH; - _InitializeEdgeIter(ctx, v, eits_[hop_ - 1]); + _InitializeEdgeIter(ctx, v, eits_[hop_ - 1], expand_counts_[hop_ - 1]); } while (ctx->path_unique_ && pattern_graph_->VisitedEdges().Contains(eits_[hop_ - 1])) { + expand_counts_[hop_ - 1] += 1; eits_[hop_ - 1].Next(); } - } while (!eits_[hop_ - 1].IsValid()); + } while (!eits_[hop_ - 1].IsValid() || !PerNodeLimit(ctx, hop_ - 1)); neighbor_->PushVid(eits_[hop_ - 1].GetNbr(expand_direction_)); relp_->path_.Append(eits_[hop_ - 1].GetUid()); // TODO(anyone) remove in last hop @@ -302,6 +315,7 @@ class VarLenExpand : public OpBase { bool collect_all_; ExpandTowards expand_direction_; std::vector &eits_; + std::vector expand_counts_; enum State { Uninitialized, /* ExpandAll wasn't initialized it. */ Resetted, /* ExpandAll was just restarted. */ @@ -333,6 +347,7 @@ class VarLenExpand : public OpBase { start_rec_idx_ = sit->second.id; nbr_rec_idx_ = dit->second.id; relp_rec_idx_ = rit->second.id; + expand_counts_.resize(eits_.size()); state_ = Uninitialized; } diff --git a/src/cypher/execution_plan/runtime_context.h b/src/cypher/execution_plan/runtime_context.h index 9408e3542e..217bed8c87 100644 --- a/src/cypher/execution_plan/runtime_context.h +++ b/src/cypher/execution_plan/runtime_context.h @@ -62,6 +62,7 @@ class RTContext : public SubmitQueryContext { std::unique_ptr txn_ = nullptr; std::unique_ptr result_info_ = nullptr; std::unique_ptr result_ = nullptr; + std::optional per_node_limit_ = std::nullopt; RTContext() = default; diff --git a/src/protobuf/ha.proto b/src/protobuf/ha.proto index 321ee64581..f3c3e454e5 100644 --- a/src/protobuf/ha.proto +++ b/src/protobuf/ha.proto @@ -504,6 +504,7 @@ message GraphQueryRequest { required bool result_in_json_format = 5; optional string graph = 6; optional double timeout = 7; + optional int64 per_node_limit = 8; }; message GraphQueryResult { diff --git a/src/server/state_machine.cpp b/src/server/state_machine.cpp index 89c5a6dbf7..18d20ede17 100644 --- a/src/server/state_machine.cpp +++ b/src/server/state_machine.cpp @@ -923,6 +923,9 @@ bool lgraph::StateMachine::ApplyGraphQueryRequest(const LGraphRequest* lgraph_re auto field_access = galaxy_->GetRoleFieldAccessLevel(user, req.graph()); cypher::RTContext ctx(this, galaxy_.get(), lgraph_req->token(), user, req.graph(), field_access); + if (req.has_per_node_limit()) { + ctx.per_node_limit_ = req.per_node_limit(); + } if (lgraph_req->has_is_write_op()) { is_write = lgraph_req->is_write_op(); } else {