From d8870d72a9b47fea7099d54121fff7dda3a9233b Mon Sep 17 00:00:00 2001 From: xiaoxmeng Date: Wed, 20 Sep 2023 18:01:57 -0700 Subject: [PATCH] Add driver arbitration state check callback (#6656) Summary: This is a leftover from https://github.com/facebookincubator/velox/issues/6643. Pull Request resolved: https://github.com/facebookincubator/velox/pull/6656 Reviewed By: tanjialiang Differential Revision: D49469932 Pulled By: xiaoxmeng fbshipit-source-id: 16a46afc634e97259b0f20d380d90958d58fd3e5 --- .../memory/tests/SharedArbitratorTest.cpp | 14 +------- velox/exec/Driver.cpp | 13 ++++++++ velox/exec/Driver.h | 32 ++++++++++++------- 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/velox/common/memory/tests/SharedArbitratorTest.cpp b/velox/common/memory/tests/SharedArbitratorTest.cpp index 0d1e0acf8b1d..ed14558403e3 100644 --- a/velox/common/memory/tests/SharedArbitratorTest.cpp +++ b/velox/common/memory/tests/SharedArbitratorTest.cpp @@ -266,8 +266,6 @@ class FakeMemoryReclaimer : public MemoryReclaimer { auto* driver = driverThreadCtx->driverCtx.driver; ASSERT_TRUE(driver != nullptr); if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) { - // There is no need for arbitration if the associated task has already - // terminated. VELOX_FAIL("Terminate detected when entering suspension"); } } @@ -331,17 +329,7 @@ class SharedArbitrationTest : public exec::test::HiveConnectorTestBase { options.memoryPoolInitCapacity = memoryPoolInitCapacity; options.memoryPoolTransferCapacity = memoryPoolTransferCapacity; options.checkUsageLeak = true; - options.arbitrationStateCheckCb = [](MemoryPool& pool) { - const auto* driverThreadCtx = driverThreadContext(); - if (driverThreadCtx != nullptr) { - if (!driverThreadCtx->driverCtx.driver->state().isSuspended) { - LOG(ERROR) - << "false " - << driverThreadCtx->driverCtx.driver->state().toJsonString(); - } - ASSERT_TRUE(driverThreadCtx->driverCtx.driver->state().isSuspended); - } - }; + options.arbitrationStateCheckCb = driverArbitrationStateCheck; memoryManager_ = std::make_unique(options); ASSERT_EQ(memoryManager_->arbitrator()->kind(), "SHARED"); arbitrator_ = static_cast(memoryManager_->arbitrator()); diff --git a/velox/exec/Driver.cpp b/velox/exec/Driver.cpp index 205304257702..ab82a91528e6 100644 --- a/velox/exec/Driver.cpp +++ b/velox/exec/Driver.cpp @@ -828,6 +828,19 @@ std::string Driver::toJsonString() const { return folly::toPrettyJson(obj); } +void driverArbitrationStateCheck(memory::MemoryPool& pool) { + const auto* driverThreadCtx = driverThreadContext(); + if (driverThreadCtx != nullptr) { + Driver* driver = driverThreadCtx->driverCtx.driver; + if (!driver->state().isSuspended) { + VELOX_FAIL( + "Driver thread is not suspended under memory arbitration processing: {}, request memory pool: {}", + driver->toString(), + pool.name()); + } + } +} + SuspendedSection::SuspendedSection(Driver* driver) : driver_(driver) { if (driver->task()->enterSuspended(driver->state()) != StopReason::kNone) { VELOX_FAIL("Terminate detected when entering suspended section"); diff --git a/velox/exec/Driver.h b/velox/exec/Driver.h index e28fd3e4c6df..2b11f8674ddb 100644 --- a/velox/exec/Driver.h +++ b/velox/exec/Driver.h @@ -280,15 +280,15 @@ class Driver : public std::enable_shared_from_this { void initializeOperatorStats(std::vector& stats); - // Close operators and add operator stats to the task. + /// Close operators and add operator stats to the task. void closeOperators(); - // Returns true if all operators between the source and 'aggregation' are - // order-preserving and do not increase cardinality. + /// Returns true if all operators between the source and 'aggregation' are + /// order-preserving and do not increase cardinality. bool mayPushdownAggregation(Operator* aggregation) const; - // Returns a subset of channels for which there are operators upstream from - // filterSource that accept dynamically generated filters. + /// Returns a subset of channels for which there are operators upstream from + /// filterSource that accept dynamically generated filters. std::unordered_set canPushdownFilters( const Operator* filterSource, const std::vector& channels) const; @@ -300,7 +300,7 @@ class Driver : public std::enable_shared_from_this { /// Returns the Operator with 'operatorId' or nullptr if not found. Operator* findOperator(int32_t operatorId) const; - // Returns a list of all operators. + /// Returns a list of all operators. std::vector operators() const; std::string toString() const; @@ -315,8 +315,8 @@ class Driver : public std::enable_shared_from_this { return ctx_->task; } - // Updates the stats in Task and frees resources. Only called by Task for - // closing non-running Drivers. + /// Updates the stats in Task and frees resources. Only called by Task for + /// closing non-running Drivers. void closeByTask(); BlockingReason blockingReason() const { @@ -349,10 +349,10 @@ class Driver : public std::enable_shared_from_this { // position in the pipeline. void pushdownFilters(int operatorIndex); - /// If 'trackOperatorCpuUsage_' is true, returns initialized timer object to - /// track cpu and wall time of an operation. Returns null otherwise. - /// The delta CpuWallTiming object would be passes to 'func' upon - /// destruction of the timer. + // If 'trackOperatorCpuUsage_' is true, returns initialized timer object to + // track cpu and wall time of an operation. Returns null otherwise. + // The delta CpuWallTiming object would be passes to 'func' upon + // destruction of the timer. template std::unique_ptr> createDeltaCpuWallTimer(F&& func) { return trackOperatorCpuUsage_ @@ -397,6 +397,14 @@ class Driver : public std::enable_shared_from_this { friend struct DriverFactory; }; +/// Callback used by memory arbitration to check if a driver thread under memory +/// arbitration has been put in suspension state. This is to prevent arbitration +/// deadlock as the arbitrator might reclaim memory from the task of the driver +/// thread which is under arbitration. The task reclaim needs to wait for the +/// drivers to go off thread. A suspended driver thread is not counted as +/// running. +void driverArbitrationStateCheck(memory::MemoryPool& pool); + using OperatorSupplier = std::function< std::unique_ptr(int32_t operatorId, DriverCtx* ctx)>;