diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 41cb39d1b49a3..6fc28f822f8b4 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -1146,11 +1146,8 @@ std::string blockingReasonToString(BlockingReason reason) { return "kYield"; case BlockingReason::kWaitForArbitration: return "kWaitForArbitration"; - default: - break; } - VELOX_UNREACHABLE( - fmt::format("Unknown blocking reason {}", static_cast(reason))); + VELOX_UNREACHABLE(); } DriverThreadContext* driverThreadContext() { diff --git a/velox/exec/Task.cpp b/velox/exec/Task.cpp index 5a24f8e497240..dcf4aedba8f04 100644 --- a/velox/exec/Task.cpp +++ b/velox/exec/Task.cpp @@ -647,7 +647,9 @@ RowVectorPtr Task::next(ContinueFuture* future) { } VELOX_CHECK_EQ( - state_, TaskState::kRunning, "Task has already finished processing."); + static_cast(state_), + static_cast(kRunning), + "Task has already finished processing."); // On first call, create the drivers. if (driverFactories_.empty()) { @@ -682,11 +684,6 @@ RowVectorPtr Task::next(ContinueFuture* future) { } drivers_ = std::move(drivers); - driverBlockingStates_.reserve(drivers_.size()); - for (auto i = 0; i < drivers_.size(); ++i) { - driverBlockingStates_.emplace_back( - std::make_unique(drivers_[i].get())); - } } // Run drivers one at a time. If a driver blocks, continue running the other @@ -701,10 +698,7 @@ RowVectorPtr Task::next(ContinueFuture* future) { int runnableDrivers = 0; int blockedDrivers = 0; for (auto i = 0; i < numDrivers; ++i) { - // Holds a reference to driver for access as async task terminate might - // remove drivers from 'drivers_' slot. - auto driver = getDriver(i); - if (driver == nullptr) { + if (drivers_[i] == nullptr) { // This driver has finished processing. continue; } @@ -715,25 +709,16 @@ RowVectorPtr Task::next(ContinueFuture* future) { continue; } - ContinueFuture blockFuture = ContinueFuture::makeEmpty(); - if (driverBlockingStates_[i]->blocked(&blockFuture)) { - VELOX_CHECK(blockFuture.valid()); - futures[i] = std::move(blockFuture); - // This driver is still blocked. - ++blockedDrivers; - continue; - } ++runnableDrivers; ContinueFuture driverFuture = ContinueFuture::makeEmpty(); - auto result = driver->next(&driverFuture); - if (result != nullptr) { - VELOX_CHECK(!driverFuture.valid()); + auto result = drivers_[i]->next(&driverFuture); + if (result) { return result; } if (driverFuture.valid()) { - driverBlockingStates_[i]->setDriverFuture(driverFuture); + futures[i] = std::move(driverFuture); } if (error()) { @@ -743,7 +728,7 @@ RowVectorPtr Task::next(ContinueFuture* future) { if (runnableDrivers == 0) { if (blockedDrivers > 0) { - if (future == nullptr) { + if (!future) { VELOX_FAIL( "Cannot make progress as all remaining drivers are blocked and user are not expected to wait."); } else { @@ -753,7 +738,7 @@ RowVectorPtr Task::next(ContinueFuture* future) { notReadyFutures.emplace_back(std::move(continueFuture)); } } - *future = folly::collectAny(std::move(notReadyFutures)).unit(); + *future = folly::collectAll(std::move(notReadyFutures)).unit(); } } return nullptr; @@ -761,12 +746,6 @@ RowVectorPtr Task::next(ContinueFuture* future) { } } -std::shared_ptr Task::getDriver(uint32_t driverId) const { - VELOX_CHECK_LT(driverId, drivers_.size()); - std::unique_lock l(mutex_); - return drivers_[driverId]; -} - void Task::start(uint32_t maxDrivers, uint32_t concurrentSplitGroups) { facebook::velox::process::ThreadDebugInfo threadDebugInfo{ queryCtx()->queryId(), taskId_, nullptr}; @@ -1501,7 +1480,7 @@ void Task::noMoreSplits(const core::PlanNodeId& planNodeId) { } if (allFinished) { - terminate(TaskState::kFinished); + terminate(kFinished); } } @@ -3123,68 +3102,4 @@ void Task::MemoryReclaimer::abort( << "Timeout waiting for task to complete during query memory aborting."; } } - -void Task::DriverBlockingState::setDriverFuture(ContinueFuture& driverFuture) { - VELOX_CHECK(!blocked_); - { - std::lock_guard l(mutex_); - VELOX_CHECK(promises_.empty()); - VELOX_CHECK_NULL(error_); - blocked_ = true; - } - std::move(driverFuture) - .via(&folly::InlineExecutor::instance()) - .thenValue( - [&, driverHolder = driver_->shared_from_this()](auto&& /* unused */) { - std::vector> promises; - { - std::lock_guard l(mutex_); - VELOX_CHECK(blocked_); - VELOX_CHECK_NULL(error_); - promises = std::move(promises_); - blocked_ = false; - } - for (auto& promise : promises) { - promise->setValue(); - } - }) - .thenError( - folly::tag_t{}, - [&, driverHolder = driver_->shared_from_this()]( - std::exception const& e) { - std::lock_guard l(mutex_); - VELOX_CHECK(blocked_); - VELOX_CHECK_NULL(error_); - try { - VELOX_FAIL( - "A driver future from task {} was realized with error: {}", - driver_->task()->taskId(), - e.what()); - } catch (const VeloxException&) { - error_ = std::current_exception(); - } - blocked_ = false; - }); -} - -bool Task::DriverBlockingState::blocked(ContinueFuture* future) { - VELOX_CHECK_NOT_NULL(future); - std::lock_guard l(mutex_); - if (error_ != nullptr) { - std::rethrow_exception(error_); - } - if (!blocked_) { - VELOX_CHECK(promises_.empty()); - return false; - } - auto [blockPromise, blockFuture] = - makeVeloxContinuePromiseContract(fmt::format( - "DriverBlockingState {} from task {}", - driver_->driverCtx()->driverId, - driver_->task()->taskId())); - *future = std::move(blockFuture); - promises_.emplace_back( - std::make_unique(std::move(blockPromise))); - return true; -} } // namespace facebook::velox::exec diff --git a/velox/exec/Task.h b/velox/exec/Task.h index 12233aa7c2b88..3ba28f6a572d3 100644 --- a/velox/exec/Task.h +++ b/velox/exec/Task.h @@ -613,13 +613,13 @@ class Task : public std::enable_shared_from_this { /// realized when the last thread stops running for 'this'. This is used to /// mark cancellation by the user. ContinueFuture requestCancel() { - return terminate(TaskState::kCanceled); + return terminate(kCanceled); } /// Like requestCancel but sets end state to kAborted. This is for stopping /// Tasks due to failures of other parts of the query. ContinueFuture requestAbort() { - return terminate(TaskState::kAborted); + return terminate(kAborted); } void requestYield() { @@ -996,8 +996,6 @@ class Task : public std::enable_shared_from_this { // trace enabled. void maybeInitTrace(); - std::shared_ptr getDriver(uint32_t driverId) const; - // Universally unique identifier of the task. Used to identify the task when // calling TaskListener. const std::string uuid_; @@ -1069,39 +1067,6 @@ class Task : public std::enable_shared_from_this { std::vector> driverFactories_; std::vector> drivers_; - - // Tracks the blocking state for each driver under serialized execution mode. - class DriverBlockingState { - public: - explicit DriverBlockingState(const Driver* driver) : driver_(driver) { - VELOX_CHECK_NOT_NULL(driver_); - } - - /// Sets driver future by setting the continuation callback via inline - /// executor. - void setDriverFuture(ContinueFuture& diverFuture); - - /// Indicates if the associated driver is blocked or not. If blocked, - /// 'future' is set which becomes realized when the driver is unblocked. - /// - /// NOTE: the function throws if the driver has encountered error. - bool blocked(ContinueFuture* future); - - private: - const Driver* const driver_; - - mutable std::mutex mutex_; - // Indicates if the associated driver is blocked or not. - bool blocked_{false}; - // Sets the driver future error if not null. - std::exception_ptr error_{nullptr}; - // Promises to fulfill when the driver is unblocked. - std::vector> promises_; - }; - - // Tracks the driver blocking state under serialized execution mode. - std::vector> driverBlockingStates_; - // When Drivers are closed by the Task, there is a chance that race and/or // bugs can cause such Drivers to be held forever, in turn holding a pointer // to the Task making it a zombie Tasks. This vector is used to keep track of diff --git a/velox/exec/TaskStructs.h b/velox/exec/TaskStructs.h index 7d92366495897..3ddc147b65274 100644 --- a/velox/exec/TaskStructs.h +++ b/velox/exec/TaskStructs.h @@ -27,24 +27,8 @@ class MergeSource; class MergeJoinSource; struct Split; -#ifdef VELOX_ENABLE_BACKWARD_COMPATIBILITY -enum TaskState { - kRunning = 0, - kFinished = 1, - kCanceled = 2, - kAborted = 3, - kFailed = 4 -}; -#else /// Corresponds to Presto TaskState, needed for reporting query completion. -enum class TaskState : int { - kRunning = 0, - kFinished = 1, - kCanceled = 2, - kAborted = 3, - kFailed = 4 -}; -#endif +enum TaskState { kRunning, kFinished, kCanceled, kAborted, kFailed }; std::string taskStateString(TaskState state); @@ -155,13 +139,3 @@ struct SplitGroupState { }; } // namespace facebook::velox::exec - -template <> -struct fmt::formatter - : formatter { - auto format(facebook::velox::exec::TaskState state, format_context& ctx) - const { - return formatter::format( - facebook::velox::exec::taskStateString(state), ctx); - } -}; diff --git a/velox/exec/tests/LocalPartitionTest.cpp b/velox/exec/tests/LocalPartitionTest.cpp index 86cd3867bebf9..81cd0210f7fca 100644 --- a/velox/exec/tests/LocalPartitionTest.cpp +++ b/velox/exec/tests/LocalPartitionTest.cpp @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/PlanNodeStats.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/HiveConnectorTestBase.h" @@ -22,7 +21,6 @@ using namespace facebook::velox; using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; -using namespace facebook::velox::common::testutil; class LocalPartitionTest : public HiveConnectorTestBase { protected: @@ -30,10 +28,6 @@ class LocalPartitionTest : public HiveConnectorTestBase { HiveConnectorTestBase::SetUp(); } - void TearDown() override { - Operator::unregisterAllOperators(); - } - template FlatVectorPtr makeFlatSequence(T start, vector_size_t size) { return makeFlatVector(size, [start](auto row) { return start + row; }); @@ -541,7 +535,7 @@ TEST_F(LocalPartitionTest, earlyCancelation) { } // Wait for task to transition to final state. - waitForTaskCompletion(task, exec::TaskState::kCanceled); + waitForTaskCompletion(task, exec::kCanceled); // Make sure there is only one reference to Task left, i.e. no Driver is // blocked forever. @@ -577,7 +571,7 @@ TEST_F(LocalPartitionTest, producerError) { ASSERT_THROW(while (cursor->moveNext()) { ; }, VeloxException); // Wait for task to transition to failed state. - waitForTaskCompletion(task, exec::TaskState::kFailed); + waitForTaskCompletion(task, exec::kFailed); // Make sure there is only one reference to Task left, i.e. no Driver is // blocked forever. @@ -643,259 +637,3 @@ TEST_F(LocalPartitionTest, unionAllLocalExchange) { ")"); } } - -namespace { -using BlockingCallback = std::function; -using FinishCallback = std::function; - -class BlockingNode : public core::PlanNode { - public: - BlockingNode(const core::PlanNodeId& id, const core::PlanNodePtr& input) - : PlanNode(id), sources_{input} {} - - const RowTypePtr& outputType() const override { - return sources_[0]->outputType(); - } - - const std::vector>& sources() const override { - return sources_; - } - - std::string_view name() const override { - return "BlockingNode"; - } - - private: - void addDetails(std::stringstream& /* stream */) const override {} - std::vector sources_; -}; - -class BlockingOperator : public Operator { - public: - BlockingOperator( - DriverCtx* ctx, - int32_t id, - const std::shared_ptr& node, - const BlockingCallback& blockingCallback, - const FinishCallback& finishCallback) - : Operator(ctx, node->outputType(), id, node->id(), "BlockedNoFuture"), - blockingCallback_(blockingCallback), - finishCallback_(finishCallback) {} - - bool needsInput() const override { - return !noMoreInput_ && !input_; - } - - void addInput(RowVectorPtr input) override { - input_ = std::move(input); - } - - RowVectorPtr getOutput() override { - return std::move(input_); - } - - bool isFinished() override { - const bool finished = noMoreInput_ && input_ == nullptr; - finishCallback_(finished); - return finished; - } - - BlockingReason isBlocked(ContinueFuture* future) override { - return blockingCallback_(future); - } - - private: - const BlockingCallback blockingCallback_; - const FinishCallback finishCallback_; -}; - -class BlockingNodeFactory : public Operator::PlanNodeTranslator { - public: - explicit BlockingNodeFactory( - const BlockingCallback& blockingCallback, - const FinishCallback& finishCallback) - : blockingCallback_(blockingCallback), finishCallback_(finishCallback) {} - - std::unique_ptr toOperator( - DriverCtx* ctx, - int32_t id, - const core::PlanNodePtr& node) override { - return std::make_unique( - ctx, - id, - std::dynamic_pointer_cast(node), - blockingCallback_, - finishCallback_); - } - - std::optional maxDrivers( - const core::PlanNodePtr& /*unused*/) override { - return std::numeric_limits::max(); - } - - private: - const BlockingCallback blockingCallback_; - const FinishCallback finishCallback_; -}; -} // namespace - -TEST_F(LocalPartitionTest, unionAllLocalExchangeWithInterDependency) { - const auto data1 = makeRowVector({"d0"}, {makeFlatVector({"x"})}); - const auto data2 = makeRowVector({"e0"}, {makeFlatVector({"y"})}); - - for (bool serialExecutionMode : {false, true}) { - SCOPED_TRACE(fmt::format("serialExecutionMode {}", serialExecutionMode)); - Operator::unregisterAllOperators(); - - std::mutex mutex; - std::vector promises; - promises.reserve(2); - std::vector futures; - futures.reserve(2); - for (int i = 0; i < 2; ++i) { - auto [blockPromise, blockFuture] = makeVeloxContinuePromiseContract( - "unionAllLocalExchangeWithInterDependency"); - promises.push_back(std::move(blockPromise)); - futures.push_back(std::move(blockFuture)); - } - - std::atomic_uint32_t numBlocks{0}; - auto blockingCallback = [&](ContinueFuture* future) -> BlockingReason { - std::lock_guard l(mutex); - if (numBlocks >= 2) { - return BlockingReason::kNotBlocked; - } - *future = std::move(futures[numBlocks]); - ++numBlocks; - return BlockingReason::kWaitForConsumer; - }; - - auto finishCallback = [&](bool finished) { - if (!finished) { - return; - } - std::lock_guard l(mutex); - for (auto& promise : promises) { - if (!promise.isFulfilled()) { - promise.setValue(); - } - } - }; - - Operator::registerOperator(std::make_unique( - std::move(blockingCallback), std::move(finishCallback))); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .localPartitionRoundRobin( - {PlanBuilder(planNodeIdGenerator) - .values({data1}) - .project({"d0 as c0"}) - .addNode([](const core::PlanNodeId& id, - const core::PlanNodePtr& input) { - return std::make_shared(id, input); - }) - .planNode(), - PlanBuilder(planNodeIdGenerator) - .values({data2}) - .project({"e0 as c0"}) - .addNode([](const core::PlanNodeId& id, - const core::PlanNodePtr& input) { - return std::make_shared(id, input); - }) - .planNode()}) - .project({"length(c0)"}) - .planNode(); - - auto thread = std::thread([&]() { - AssertQueryBuilder(duckDbQueryRunner_) - .serialExecution(serialExecutionMode) - .plan(std::move(plan)) - .assertResults( - "SELECT length(c0) FROM (" - " SELECT * FROM (VALUES ('x')) as t1(c0) UNION ALL " - " SELECT * FROM (VALUES ('y')) as t2(c0)" - ")"); - }); - - while (numBlocks != 2) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); // NOLINT - } - promises[0].setValue(); - - thread.join(); - } -} - -TEST_F( - LocalPartitionTest, - taskErrorWithBlockedDriverFutureUnderSerializedExecutionMode) { - const auto data1 = makeRowVector({"d0"}, {makeFlatVector({"x"})}); - const auto data2 = makeRowVector({"e0"}, {makeFlatVector({"y"})}); - - Operator::unregisterAllOperators(); - - std::mutex mutex; - auto contract = makeVeloxContinuePromiseContract( - "driverFutureErrorUnderSerializedExecutionMode"); - - std::atomic_uint32_t numBlocks{0}; - auto blockingCallback = [&](ContinueFuture* future) -> BlockingReason { - std::lock_guard l(mutex); - if (numBlocks++ > 0) { - return BlockingReason::kNotBlocked; - } - *future = std::move(contract.second); - return BlockingReason::kWaitForConsumer; - }; - - auto finishCallback = [&](bool /*unused*/) {}; - - Operator::registerOperator(std::make_unique( - std::move(blockingCallback), std::move(finishCallback))); - - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .localPartitionRoundRobin( - {PlanBuilder(planNodeIdGenerator) - .values({data1}) - .project({"d0 as c0"}) - .addNode([](const core::PlanNodeId& id, - const core::PlanNodePtr& input) { - return std::make_shared(id, input); - }) - .planNode(), - PlanBuilder(planNodeIdGenerator) - .values({data2}) - .project({"e0 as c0"}) - .addNode([](const core::PlanNodeId& id, - const core::PlanNodePtr& input) { - return std::make_shared(id, input); - }) - .planNode()}) - .project({"length(c0)"}) - .planNode(); - - auto thread = std::thread([&]() { - VELOX_ASSERT_THROW( - AssertQueryBuilder(duckDbQueryRunner_) - .serialExecution(true) - .plan(std::move(plan)) - .assertResults( - "SELECT length(c0) FROM (" - " SELECT * FROM (VALUES ('x')) as t1(c0) UNION ALL " - " SELECT * FROM (VALUES ('y')) as t2(c0)" - ")"), - ""); - }); - - while (numBlocks < 2) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); // NOLINT - } - std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT - - auto tasks = Task::getRunningTasks(); - ASSERT_EQ(tasks.size(), 1); - tasks[0]->requestAbort().wait(); - thread.join(); -}