From deffc3ab34861f718a47d3efdb13e187e6878964 Mon Sep 17 00:00:00 2001
From: xiaoxmeng <xiaoxmeng@fb.com>
Date: Mon, 11 Mar 2024 23:51:40 -0700
Subject: [PATCH] Fix early task termination with grouped execution

---
 velox/exec/Task.cpp                       |  9 ++++-
 velox/exec/tests/GroupedExecutionTest.cpp | 46 +++++++++++++++++++++++
 2 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp
index dd58e2bda700..0410f9f06db7 100644
--- a/velox/exec/Task.cpp
+++ b/velox/exec/Task.cpp
@@ -1371,7 +1371,12 @@ bool Task::checkNoMoreSplitGroupsLocked() {
   // we should review the total number of drivers, which initially is set to
   // process all split groups, but in reality workers share split groups and
   // each worker processes only a part of them, meaning much less than all.
-  if (allNodesReceivedNoMoreSplitsMessageLocked()) {
+  //
+  // NOTE: we shall only do task finish check after the task has been started
+  // which initializes 'numDriversPerSplitGroup_', otherwise the task will
+  // finish early.
+  if ((numDriversPerSplitGroup_ != 0) &&
+      allNodesReceivedNoMoreSplitsMessageLocked()) {
     numTotalDrivers_ = seenSplitGroups_.size() * numDriversPerSplitGroup_ +
         numDriversUngrouped_;
     if (groupedPartitionedOutput_) {
@@ -1586,7 +1591,7 @@ bool Task::checkIfFinishedLocked() {
   // TODO Add support for terminating processing early in grouped execution.
   bool allFinished = numFinishedDrivers_ == numTotalDrivers_;
   if (!allFinished && isUngroupedExecution()) {
-    auto outputPipelineId = getOutputPipelineId();
+    const auto outputPipelineId = getOutputPipelineId();
     if (splitGroupStates_[kUngroupedGroupId].numFinishedOutputDrivers ==
         numDrivers(outputPipelineId)) {
       allFinished = true;
diff --git a/velox/exec/tests/GroupedExecutionTest.cpp b/velox/exec/tests/GroupedExecutionTest.cpp
index a7acd055c63a..be534714ff44 100644
--- a/velox/exec/tests/GroupedExecutionTest.cpp
+++ b/velox/exec/tests/GroupedExecutionTest.cpp
@@ -540,4 +540,50 @@ TEST_F(GroupedExecutionTest, groupedExecution) {
   EXPECT_EQ(numRead, numSplits * 10'000);
 }
 
+TEST_F(GroupedExecutionTest, allGroupSplitsReceivedBeforeTaskStart) {
+  // All splits point at the same source file; 6 splits are added in total.
+  const size_t numSplits{6};
+  auto sourceData = makeVectors(10, 1'000);
+  const auto sourceFile = TempFilePath::create();
+  writeToFile(sourceFile->path, sourceData);
+
+  // Grouped execution over the scan: 3 groups, 1 concurrent group, 1 driver.
+  CursorParameters params;
+  params.executionStrategy = core::ExecutionStrategy::kGrouped;
+  params.numSplitGroups = 3;
+  params.numConcurrentSplitGroups = 1;
+  params.maxDrivers = 1;
+  params.planNode = tableScanNode(ROW({}, {}));
+  params.groupedExecutionLeafNodeIds.emplace(params.planNode->id());
+
+  // Creating the cursor also creates the task, but does not start it.
+  auto cursor = TaskCursor::create(params);
+  auto task = cursor->task();
+
+  // Queue all splits before the task starts, in group order 0, 1, 2, 0, 1, 2.
+  for (int pass = 0; pass < 2; ++pass) {
+    for (int group = 0; group < 3; ++group) {
+      task->addSplit("0", makeHiveSplitWithGroup(sourceFile->path, group));
+    }
+  }
+  task->noMoreSplits("0");
+
+  // Only now start the task; it must still run every split group.
+  cursor->start();
+  waitForFinishedDrivers(task, 3);
+  ASSERT_EQ(
+      getCompletedSplitGroups(task), std::unordered_set<int32_t>({0, 1, 2}));
+
+  // Drain the cursor, counting rows; the scan projects no columns.
+  int32_t numReadRows{0};
+  while (cursor->moveNext()) {
+    auto vector = cursor->current();
+    EXPECT_EQ(vector->childrenSize(), 0);
+    numReadRows += vector->size();
+  }
+
+  // Once the cursor is exhausted the task must have finished with all rows.
+  ASSERT_EQ(task->state(), exec::TaskState::kFinished);
+  ASSERT_EQ(numSplits * 10'000, numReadRows);
+}
 } // namespace facebook::velox::exec::test