From ff831beeac8297d67e426046bc88fcd8c91f2939 Mon Sep 17 00:00:00 2001 From: Xiaoxuan Meng Date: Wed, 4 Dec 2024 21:07:15 -0800 Subject: [PATCH] feat: Fix the task hanging under serialized execution mode (#11747) Summary: This is resubmit of https://github.com/facebookincubator/velox/pull/11647 with Meta internal streaming use case fix Reviewed By: Yuhta, weijiadeng-uber Differential Revision: D66708173 --- velox/exec/Task.cpp | 99 ++++++++- velox/exec/Task.h | 35 ++++ velox/exec/tests/LocalPartitionTest.cpp | 257 ++++++++++++++++++++++++ 3 files changed, 385 insertions(+), 6 deletions(-) diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index 7b10948b53a08..af7a744013860 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -682,6 +682,11 @@ RowVectorPtr Task::next(ContinueFuture* future) { } drivers_ = std::move(drivers); + driverBlockingStates_.reserve(drivers_.size()); + for (auto i = 0; i < drivers_.size(); ++i) { + driverBlockingStates_.emplace_back( + std::make_unique(drivers_[i].get())); + } } // Run drivers one at a time. If a driver blocks, continue running the other @@ -696,7 +701,10 @@ RowVectorPtr Task::next(ContinueFuture* future) { int runnableDrivers = 0; int blockedDrivers = 0; for (auto i = 0; i < numDrivers; ++i) { - if (drivers_[i] == nullptr) { + // Holds a reference to driver for access as async task terminate might + // remove drivers from 'drivers_' slot. + auto driver = getDriver(i); + if (driver == nullptr) { // This driver has finished processing. continue; } @@ -707,16 +715,25 @@ RowVectorPtr Task::next(ContinueFuture* future) { continue; } + ContinueFuture blockFuture = ContinueFuture::makeEmpty(); + if (driverBlockingStates_[i]->blocked(&blockFuture)) { + VELOX_CHECK(blockFuture.valid()); + futures[i] = std::move(blockFuture); + // This driver is still blocked. + ++blockedDrivers; + continue; + } ++runnableDrivers; ContinueFuture driverFuture = ContinueFuture::makeEmpty(); - auto result = drivers_[i]->next(&driverFuture); - if (result) { + auto result = driver->next(&driverFuture); + if (result != nullptr) { + VELOX_CHECK(!driverFuture.valid()); return result; } if (driverFuture.valid()) { - futures[i] = std::move(driverFuture); + driverBlockingStates_[i]->setDriverFuture(driverFuture); } if (error()) { @@ -726,7 +743,7 @@ RowVectorPtr Task::next(ContinueFuture* future) { if (runnableDrivers == 0) { if (blockedDrivers > 0) { - if (!future) { + if (future == nullptr) { VELOX_FAIL( "Cannot make progress as all remaining drivers are blocked and user are not expected to wait."); } else { @@ -736,7 +753,7 @@ RowVectorPtr Task::next(ContinueFuture* future) { notReadyFutures.emplace_back(std::move(continueFuture)); } } - *future = folly::collectAll(std::move(notReadyFutures)).unit(); + *future = folly::collectAny(std::move(notReadyFutures)).unit(); } } return nullptr; @@ -792,6 +809,12 @@ void Task::start(uint32_t maxDrivers, uint32_t concurrentSplitGroups) { } } +std::shared_ptr Task::getDriver(uint32_t driverId) const { + VELOX_CHECK_LT(driverId, drivers_.size()); + std::unique_lock l(mutex_); + return drivers_[driverId]; +} + void Task::checkExecutionMode(ExecutionMode mode) { VELOX_CHECK_EQ(mode, mode_, "Inconsistent task execution mode."); } @@ -3100,4 +3123,68 @@ void Task::MemoryReclaimer::abort( << "Timeout waiting for task to complete during query memory aborting."; } } + +void Task::DriverBlockingState::setDriverFuture(ContinueFuture& driverFuture) { + VELOX_CHECK(!blocked_); + { + std::lock_guard l(mutex_); + VELOX_CHECK(promises_.empty()); + VELOX_CHECK_NULL(error_); + blocked_ = true; + } + std::move(driverFuture) + .via(&folly::InlineExecutor::instance()) + .thenValue([&, driverHolder = driver_->shared_from_this()]( + auto&& /* unused */) { + std::vector> promises; + { + std::lock_guard l(mutex_); + VELOX_CHECK(blocked_); + VELOX_CHECK_NULL(error_); + promises = std::move(promises_); + blocked_ = false; + } + for (auto& promise : promises) { + promise->setValue(); + } + }) + .thenError( + folly::tag_t{}, + [&, driverHolder = driver_->shared_from_this()]( + std::exception const& e) { + std::lock_guard l(mutex_); + VELOX_CHECK(blocked_); + VELOX_CHECK_NULL(error_); + try { + VELOX_FAIL( + "A driver future from task {} was realized with error: {}", + driver_->task()->taskId(), + e.what()); + } catch (const VeloxException&) { + error_ = std::current_exception(); + } + blocked_ = false; + }); +} + +bool Task::DriverBlockingState::blocked(ContinueFuture* future) { + VELOX_CHECK_NOT_NULL(future); + std::lock_guard l(mutex_); + if (error_ != nullptr) { + std::rethrow_exception(error_); + } + if (!blocked_) { + VELOX_CHECK(promises_.empty()); + return false; + } + auto [blockPromise, blockFuture] = + makeVeloxContinuePromiseContract(fmt::format( + "DriverBlockingState {} from task {}", + driver_->driverCtx()->driverId, + driver_->task()->taskId())); + *future = std::move(blockFuture); + promises_.emplace_back( + std::make_unique(std::move(blockPromise))); + return true; +} } // namespace facebook::velox::exec diff --git a/velox/exec/Task.h b/velox/exec/Task.h index d205b15178aff..12233aa7c2b88 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -996,6 +996,8 @@ class Task : public std::enable_shared_from_this { // trace enabled. void maybeInitTrace(); + std::shared_ptr getDriver(uint32_t driverId) const; + // Universally unique identifier of the task. Used to identify the task when // calling TaskListener. const std::string uuid_; @@ -1067,6 +1069,39 @@ class Task : public std::enable_shared_from_this { std::vector> driverFactories_; std::vector> drivers_; + + // Tracks the blocking state for each driver under serialized execution mode. + class DriverBlockingState { + public: + explicit DriverBlockingState(const Driver* driver) : driver_(driver) { + VELOX_CHECK_NOT_NULL(driver_); + } + + /// Sets driver future by setting the continuation callback via inline + /// executor. + void setDriverFuture(ContinueFuture& diverFuture); + + /// Indicates if the associated driver is blocked or not. If blocked, + /// 'future' is set which becomes realized when the driver is unblocked. + /// + /// NOTE: the function throws if the driver has encountered error. + bool blocked(ContinueFuture* future); + + private: + const Driver* const driver_; + + mutable std::mutex mutex_; + // Indicates if the associated driver is blocked or not. + bool blocked_{false}; + // Sets the driver future error if not null. + std::exception_ptr error_{nullptr}; + // Promises to fulfill when the driver is unblocked. + std::vector> promises_; + }; + + // Tracks the driver blocking state under serialized execution mode. + std::vector> driverBlockingStates_; + // When Drivers are closed by the Task, there is a chance that race and/or // bugs can cause such Drivers to be held forever, in turn holding a pointer // to the Task making it a zombie Tasks. This vector is used to keep track of diff --git a/velox/exec/tests/LocalPartitionTest.cpp b/velox/exec/tests/LocalPartitionTest.cpp index 1a8dc480c45d6..f5d6245b266c3 100644 --- a/velox/exec/tests/LocalPartitionTest.cpp +++ b/velox/exec/tests/LocalPartitionTest.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" @@ -671,3 +672,259 @@ TEST_F(LocalPartitionTest, unionAllLocalExchange) { ")"); } } + +namespace { +using BlockingCallback = std::function; +using FinishCallback = std::function; + +class BlockingNode : public core::PlanNode { + public: + BlockingNode(const core::PlanNodeId& id, const core::PlanNodePtr& input) + : PlanNode(id), sources_{input} {} + + const RowTypePtr& outputType() const override { + return sources_[0]->outputType(); + } + + const std::vector>& sources() const override { + return sources_; + } + + std::string_view name() const override { + return "BlockingNode"; + } + + private: + void addDetails(std::stringstream& /* stream */) const override {} + std::vector sources_; +}; + +class BlockingOperator : public Operator { + public: + BlockingOperator( + DriverCtx* ctx, + int32_t id, + const std::shared_ptr& node, + const BlockingCallback& blockingCallback, + const FinishCallback& finishCallback) + : Operator(ctx, node->outputType(), id, node->id(), "BlockedNoFuture"), + blockingCallback_(blockingCallback), + finishCallback_(finishCallback) {} + + bool needsInput() const override { + return !noMoreInput_ && !input_; + } + + void addInput(RowVectorPtr input) override { + input_ = std::move(input); + } + + RowVectorPtr getOutput() override { + return std::move(input_); + } + + bool isFinished() override { + const bool finished = noMoreInput_ && input_ == nullptr; + finishCallback_(finished); + return finished; + } + + BlockingReason isBlocked(ContinueFuture* future) override { + return blockingCallback_(future); + } + + private: + const BlockingCallback blockingCallback_; + const FinishCallback finishCallback_; +}; + +class BlockingNodeFactory : public Operator::PlanNodeTranslator { + public: + explicit BlockingNodeFactory( + const BlockingCallback& blockingCallback, + const FinishCallback& finishCallback) + : blockingCallback_(blockingCallback), finishCallback_(finishCallback) {} + + std::unique_ptr toOperator( + DriverCtx* ctx, + int32_t id, + const core::PlanNodePtr& node) override { + auto blockingNode = std::dynamic_pointer_cast(node); + if (blockingNode == nullptr) { + return nullptr; + } + return std::make_unique( + ctx, id, blockingNode, blockingCallback_, finishCallback_); + } + + std::optional maxDrivers( + const core::PlanNodePtr& /*unused*/) override { + return std::numeric_limits::max(); + } + + private: + const BlockingCallback blockingCallback_; + const FinishCallback finishCallback_; +}; +} // namespace + +TEST_F(LocalPartitionTest, unionAllLocalExchangeWithInterDependency) { + const auto data1 = makeRowVector({"d0"}, {makeFlatVector({"x"})}); + const auto data2 = makeRowVector({"e0"}, {makeFlatVector({"y"})}); + + for (bool serialExecutionMode : {false, true}) { + SCOPED_TRACE(fmt::format("serialExecutionMode {}", serialExecutionMode)); + Operator::unregisterAllOperators(); + + std::mutex mutex; + std::vector promises; + promises.reserve(2); + std::vector futures; + futures.reserve(2); + for (int i = 0; i < 2; ++i) { + auto [blockPromise, blockFuture] = makeVeloxContinuePromiseContract( + "unionAllLocalExchangeWithInterDependency"); + promises.push_back(std::move(blockPromise)); + futures.push_back(std::move(blockFuture)); + } + + std::atomic_uint32_t numBlocks{0}; + auto blockingCallback = [&](ContinueFuture* future) -> BlockingReason { + std::lock_guard l(mutex); + if (numBlocks >= 2) { + return BlockingReason::kNotBlocked; + } + *future = std::move(futures[numBlocks]); + ++numBlocks; + return BlockingReason::kWaitForConsumer; + }; + + auto finishCallback = [&](bool finished) { + if (!finished) { + return; + } + std::lock_guard l(mutex); + for (auto& promise : promises) { + if (!promise.isFulfilled()) { + promise.setValue(); + } + } + }; + + Operator::registerOperator(std::make_unique( + std::move(blockingCallback), std::move(finishCallback))); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator) + .values({data1}) + .project({"d0 as c0"}) + .addNode([](const core::PlanNodeId& id, + const core::PlanNodePtr& input) { + return std::make_shared(id, input); + }) + .planNode(), + PlanBuilder(planNodeIdGenerator) + .values({data2}) + .project({"e0 as c0"}) + .addNode([](const core::PlanNodeId& id, + const core::PlanNodePtr& input) { + return std::make_shared(id, input); + }) + .planNode()}) + .project({"length(c0)"}) + .planNode(); + + auto thread = std::thread([&]() { + AssertQueryBuilder(duckDbQueryRunner_) + .serialExecution(serialExecutionMode) + .plan(std::move(plan)) + .assertResults( + "SELECT length(c0) FROM (" + " SELECT * FROM (VALUES ('x')) as t1(c0) UNION ALL " + " SELECT * FROM (VALUES ('y')) as t2(c0)" + ")"); + }); + + while (numBlocks != 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT + } + promises[0].setValue(); + + thread.join(); + } +} + +TEST_F( + LocalPartitionTest, + taskErrorWithBlockedDriverFutureUnderSerializedExecutionMode) { + const auto data1 = makeRowVector({"d0"}, {makeFlatVector({"x"})}); + const auto data2 = makeRowVector({"e0"}, {makeFlatVector({"y"})}); + + Operator::unregisterAllOperators(); + + std::mutex mutex; + auto contract = makeVeloxContinuePromiseContract( + "driverFutureErrorUnderSerializedExecutionMode"); + + std::atomic_uint32_t numBlocks{0}; + auto blockingCallback = [&](ContinueFuture* future) -> BlockingReason { + std::lock_guard l(mutex); + if (numBlocks++ > 0) { + return BlockingReason::kNotBlocked; + } + *future = std::move(contract.second); + return BlockingReason::kWaitForConsumer; + }; + + auto finishCallback = [&](bool /*unused*/) {}; + + Operator::registerOperator(std::make_unique( + std::move(blockingCallback), std::move(finishCallback))); + + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .localPartitionRoundRobin( + {PlanBuilder(planNodeIdGenerator) + .values({data1}) + .project({"d0 as c0"}) + .addNode([](const core::PlanNodeId& id, + const core::PlanNodePtr& input) { + return std::make_shared(id, input); + }) + .planNode(), + PlanBuilder(planNodeIdGenerator) + .values({data2}) + .project({"e0 as c0"}) + .addNode([](const core::PlanNodeId& id, + const core::PlanNodePtr& input) { + return std::make_shared(id, input); + }) + .planNode()}) + .project({"length(c0)"}) + .planNode(); + + auto thread = std::thread([&]() { + VELOX_ASSERT_THROW( + AssertQueryBuilder(duckDbQueryRunner_) + .serialExecution(true) + .plan(std::move(plan)) + .assertResults( + "SELECT length(c0) FROM (" + " SELECT * FROM (VALUES ('x')) as t1(c0) UNION ALL " + " SELECT * FROM (VALUES ('y')) as t2(c0)" + ")"), + ""); + }); + + while (numBlocks < 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); // NOLINT + } + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + + auto tasks = Task::getRunningTasks(); + ASSERT_EQ(tasks.size(), 1); + tasks[0]->requestAbort().wait(); + thread.join(); +}