diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index f82769d4f21d..f2e0d584bed6 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -16,7 +16,6 @@ #include #include -#include #include #include #include "velox/common/testutil/TestValue.h" diff --git a/velox/exec/Merge.cpp b/velox/exec/Merge.cpp index 924817eb19cc..ae51ae2cddda 100644 --- a/velox/exec/Merge.cpp +++ b/velox/exec/Merge.cpp @@ -74,11 +74,20 @@ void Merge::initializeTreeOfLosers() { } BlockingReason Merge::isBlocked(ContinueFuture* future) { + TestValue::adjust("facebook::velox::exec::Merge::isBlocked", this); + auto reason = addMergeSources(future); if (reason != BlockingReason::kNotBlocked) { return reason; } + // NOTE: the task might terminate early which leaves empty sources. Once it + // happens, we shall simply mark the merge operator as finished. + if (sources_.empty()) { + finished_ = true; + return BlockingReason::kNotBlocked; + } + // No merging is needed if there is only one source. if (streams_.empty() && sources_.size() > 1) { initializeTreeOfLosers(); diff --git a/velox/exec/tests/MultiFragmentTest.cpp b/velox/exec/tests/MultiFragmentTest.cpp index e749a3f32ea5..2481be764b91 100644 --- a/velox/exec/tests/MultiFragmentTest.cpp +++ b/velox/exec/tests/MultiFragmentTest.cpp @@ -34,6 +34,7 @@ using namespace facebook::velox::connector::hive; using namespace facebook::velox::common::testutil; using namespace facebook::velox::memory; +using facebook::velox::common::testutil::TestValue; using facebook::velox::test::BatchMaker; class MultiFragmentTest : public HiveConnectorTestBase { @@ -1288,3 +1289,59 @@ TEST_F(MultiFragmentTest, taskTerminateWithPendingOutputBuffers) { ASSERT_EQ(task.use_count(), 1); task.reset(); } + +DEBUG_ONLY_TEST_F(MultiFragmentTest, mergeWithEarlyTermination) { + setupSources(10, 1000); + + std::vector> filePaths( + filePaths_.begin(), filePaths_.begin()); + + std::vector partialSortTaskIds; + auto sortTaskId = makeTaskId("orderby", 0); + partialSortTaskIds.push_back(sortTaskId); + auto planNodeIdGenerator = std::make_shared(); + auto partialSortPlan = PlanBuilder(planNodeIdGenerator) + .localMerge( + {"c0"}, + {PlanBuilder(planNodeIdGenerator) + .tableScan(rowType_) + .orderBy({"c0"}, true) + .planNode()}) + .partitionedOutput({}, 1) + .planNode(); + + auto partialSortTask = makeTask(sortTaskId, partialSortPlan, 1); + Task::start(partialSortTask, 1); + addHiveSplits(partialSortTask, filePaths); + + std::atomic blockMergeOnce{true}; + folly::EventCount mergeIsBlockedWait; + auto mergeIsBlockedWaitKet = mergeIsBlockedWait.prepareWait(); + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Merge::isBlocked", + std::function([&](const Operator* op) { + if (op->operatorType() != "MergeExchange") { + return; + } + if (!blockMergeOnce.exchange(false)) { + return; + } + mergeIsBlockedWait.wait(mergeIsBlockedWaitKet); + // Trigger early termination. + op->testingOperatorCtx()->task()->requestAbort(); + })); + + auto finalSortTaskId = makeTaskId("orderby", 1); + auto finalSortPlan = PlanBuilder() + .mergeExchange(partialSortPlan->outputType(), {"c0"}) + .partitionedOutput({}, 1) + .planNode(); + auto finalSortTask = makeTask(finalSortTaskId, finalSortPlan, 0); + Task::start(finalSortTask, 1); + addRemoteSplits(finalSortTask, partialSortTaskIds); + + mergeIsBlockedWait.notify(); + + ASSERT_TRUE(waitForTaskCompletion(partialSortTask.get(), 1'000'000'000)); + ASSERT_TRUE(waitForTaskAborted(finalSortTask.get(), 1'000'000'000)); +}