diff --git a/velox/common/memory/ArbitrationParticipant.cpp b/velox/common/memory/ArbitrationParticipant.cpp index 8ea63db8b485a..d48868cf2d88e 100644 --- a/velox/common/memory/ArbitrationParticipant.cpp +++ b/velox/common/memory/ArbitrationParticipant.cpp @@ -267,7 +267,7 @@ uint64_t ArbitrationParticipant::reclaim( if (targetBytes == 0) { return 0; } - std::lock_guard l(reclaimLock_); + ArbitrationOperationTimedLock l(reclaimMutex_); TestValue::adjust( "facebook::velox::memory::ArbitrationParticipant::reclaim", this); uint64_t reclaimedBytes{0}; @@ -320,7 +320,7 @@ uint64_t ArbitrationParticipant::shrinkLocked(bool reclaimAll) { uint64_t ArbitrationParticipant::abort( const std::exception_ptr& error) noexcept { - std::lock_guard l(reclaimLock_); + ArbitrationOperationTimedLock l(reclaimMutex_); return abortLocked(error); } @@ -353,13 +353,6 @@ uint64_t ArbitrationParticipant::abortLocked( return shrinkLocked(/*reclaimAll=*/true); } -bool ArbitrationParticipant::waitForReclaimOrAbort( - uint64_t maxWaitTimeNs) const { - std::unique_lock l( - reclaimLock_, std::chrono::nanoseconds(maxWaitTimeNs)); - return l.owns_lock(); -} - bool ArbitrationParticipant::hasRunningOp() const { std::lock_guard l(stateLock_); return runningOp_ != nullptr; @@ -408,4 +401,37 @@ std::string ArbitrationCandidate::toString() const { succinctBytes(reclaimableUsedCapacity), succinctBytes(reclaimableFreeCapacity)); } + +ArbitrationOperationTimedLock::ArbitrationOperationTimedLock( + std::timed_mutex& mutex) + : mutex_(mutex) { + auto arbitrationContext = memoryArbitrationContext(); + if (arbitrationContext == nullptr) { + mutex_.lock(); + return; + } + auto* operation = arbitrationContext->op; + if (operation == nullptr) { + VELOX_CHECK_EQ( + MemoryArbitrationContext::typeName(arbitrationContext->type), + MemoryArbitrationContext::typeName( + MemoryArbitrationContext::Type::kGlobal)); + mutex_.lock(); + return; + } + VELOX_CHECK_EQ( + MemoryArbitrationContext::typeName(arbitrationContext->type), + MemoryArbitrationContext::typeName( + MemoryArbitrationContext::Type::kLocal)); + if (!mutex_.try_lock_for(std::chrono::nanoseconds(operation->timeoutNs()))) { + VELOX_MEM_ARBITRATION_TIMEOUT(fmt::format( + "Memory arbitration lock timed out on memory pool: {} after running {}", + operation->participant()->name(), + succinctNanos(operation->executionTimeNs()))); + } +} + +ArbitrationOperationTimedLock::~ArbitrationOperationTimedLock() { + mutex_.unlock(); +} } // namespace facebook::velox::memory diff --git a/velox/common/memory/ArbitrationParticipant.h b/velox/common/memory/ArbitrationParticipant.h index 8d4c677ad94b4..6e9eb7a5dd4f2 100644 --- a/velox/common/memory/ArbitrationParticipant.h +++ b/velox/common/memory/ArbitrationParticipant.h @@ -25,6 +25,16 @@ #include "velox/common/memory/Memory.h" namespace facebook::velox::memory { + +#define VELOX_MEM_ARBITRATION_TIMEOUT(errorMessage) \ + _VELOX_THROW( \ + ::facebook::velox::VeloxRuntimeError, \ + ::facebook::velox::error_source::kErrorSourceRuntime.c_str(), \ + ::facebook::velox::error_code::kMemArbitrationTimeout.c_str(), \ + /* isRetriable */ true, \ + "{}", \ + errorMessage); + namespace test { class ArbitrationParticipantTestHelper; } @@ -32,6 +42,20 @@ class ArbitrationParticipantTestHelper; class ArbitrationOperation; class ScopedArbitrationParticipant; +/// Custom lock that keeps track of the time of the ongoing arbitration +/// operation while waiting for the lock. The lock will identify if it needs to +/// apply a wait timeout by checking arbitrationCtx thread local variable. If a +/// local arbitration is ongoing on the current locking thread, timeout will +/// automatically be applied. +class ArbitrationOperationTimedLock { + public: + explicit ArbitrationOperationTimedLock(std::timed_mutex& mutex); + ~ArbitrationOperationTimedLock(); + + private: + std::timed_mutex& mutex_; +}; + /// Manages the memory arbitration operations on a query memory pool. It also /// tracks the arbitration stats during the query memory pool's lifecycle. class ArbitrationParticipant @@ -154,9 +178,9 @@ class ArbitrationParticipant /// which ensures the liveness of underlying query memory pool. If the query /// memory pool is being destroyed, then this function returns std::nullopt. /// - // NOTE: it is not safe to directly access arbitration participant as it only - // holds a weak ptr to the query memory pool. Use 'lock()' to get a scoped - // arbitration participant for access. + /// NOTE: it is not safe to directly access arbitration participant as it only + /// holds a weak ptr to the query memory pool. Use 'lock()' to get a scoped + /// arbitration participant for access. std::optional lock(); /// Returns the corresponding query memory pool. @@ -223,11 +247,6 @@ class ArbitrationParticipant return aborted_; } - /// Invoked to wait for the pending memory reclaim or abort operation to - /// complete within a 'maxWaitTimeMs' time window. The function returns false - /// if the wait has timed out. - bool waitForReclaimOrAbort(uint64_t maxWaitTimeNs) const; - /// Invoked to start arbitration operation 'op'. The operation needs to wait /// for the prior arbitration operations to finish first before executing to /// ensure the serialized execution of arbitration operations from the same @@ -333,7 +352,7 @@ class ArbitrationParticipant tsan_atomic reclaimedBytes_{0}; tsan_atomic growBytes_{0}; - mutable std::timed_mutex reclaimLock_; + mutable std::timed_mutex reclaimMutex_; friend class ScopedArbitrationParticipant; friend class test::ArbitrationParticipantTestHelper; diff --git a/velox/common/memory/Memory.h b/velox/common/memory/Memory.h index a04ef03e8724b..f460d25ffee63 100644 --- a/velox/common/memory/Memory.h +++ b/velox/common/memory/Memory.h @@ -347,7 +347,6 @@ std::shared_ptr deprecatedAddDefaultLeafMemoryPool( /// using this method can get a pool that is shared with other threads. The goal /// is to minimize lock contention while supporting such use cases. /// -/// /// TODO: deprecate this API after all the use cases are able to manage the /// lifecycle of the allocated memory pools properly. MemoryPool& deprecatedSharedLeafPool(); diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index eac6d149f4856..90832ed22b7c1 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -449,8 +449,12 @@ bool MemoryArbitrator::Stats::operator<=(const Stats& other) const { return !(*this > other); } -MemoryArbitrationContext::MemoryArbitrationContext(const MemoryPool* requestor) - : type(Type::kLocal), requestorName(requestor->name()) {} +MemoryArbitrationContext::MemoryArbitrationContext( + const MemoryPool* requestor, + ArbitrationOperation* _op) + : type(Type::kLocal), requestorName(requestor->name()), op(_op) { + VELOX_CHECK_NOT_NULL(op); +} std::string MemoryArbitrationContext::typeName( MemoryArbitrationContext::Type type) { @@ -465,8 +469,10 @@ std::string MemoryArbitrationContext::typeName( } ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext( - const MemoryPool* requestor) - : savedArbitrationCtx_(arbitrationCtx), currentArbitrationCtx_(requestor) { + const MemoryPool* requestor, + ArbitrationOperation* op) + : savedArbitrationCtx_(arbitrationCtx), + currentArbitrationCtx_(requestor, op) { arbitrationCtx = ¤tArbitrationCtx_; } diff --git a/velox/common/memory/MemoryArbitrator.h b/velox/common/memory/MemoryArbitrator.h index 590506c1ccbcd..0120c6b91a0da 100644 --- a/velox/common/memory/MemoryArbitrator.h +++ b/velox/common/memory/MemoryArbitrator.h @@ -28,6 +28,7 @@ namespace facebook::velox::memory { class MemoryPool; +class ArbitrationOperation; using MemoryArbitrationStateCheckCB = std::function; @@ -398,11 +399,11 @@ class NonReclaimableSectionGuard { const bool oldNonReclaimableSectionValue_; }; -/// The memory arbitration context which is set on per-thread local variable by -/// memory arbitrator. It is used to indicate a running thread is under memory -/// arbitration processing or not. This helps to enable sanity check such as all -/// the memory reservations during memory arbitration should come from the -/// spilling memory pool. +/// The memory arbitration context which is set as per-thread local variable by +/// memory arbitrator. It is used to indicate if a running thread is under +/// memory arbitration. This helps to enable sanity check such as all the memory +/// reservations during memory arbitration should come from the spilling memory +/// pool. struct MemoryArbitrationContext { /// Defines the type of memory arbitration. enum class Type { @@ -420,20 +421,28 @@ struct MemoryArbitrationContext { /// global memory arbitration type. const std::string requestorName; - explicit MemoryArbitrationContext(const MemoryPool* requestor); + ArbitrationOperation* const op; - MemoryArbitrationContext() : type(Type::kGlobal) {} + MemoryArbitrationContext( + const MemoryPool* requestor, + ArbitrationOperation* _op); + + MemoryArbitrationContext() : type(Type::kGlobal), op(nullptr) {} }; /// Object used to set/restore the memory arbitration context when a thread is /// under memory arbitration processing. class ScopedMemoryArbitrationContext { public: - explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor); ScopedMemoryArbitrationContext(); + explicit ScopedMemoryArbitrationContext( const MemoryArbitrationContext* context); + ScopedMemoryArbitrationContext( + const MemoryPool* requestor, + ArbitrationOperation* op); + ~ScopedMemoryArbitrationContext(); private: diff --git a/velox/common/memory/SharedArbitrator.cpp b/velox/common/memory/SharedArbitrator.cpp index 6ca1601826f07..4c8f1426e9849 100644 --- a/velox/common/memory/SharedArbitrator.cpp +++ b/velox/common/memory/SharedArbitrator.cpp @@ -64,15 +64,6 @@ T getConfig( } return defaultValue; } - -#define VELOX_MEM_ARBITRATION_TIMEOUT(errorMessage) \ - _VELOX_THROW( \ - ::facebook::velox::VeloxRuntimeError, \ - ::facebook::velox::error_source::kErrorSourceRuntime.c_str(), \ - ::facebook::velox::error_code::kMemArbitrationTimeout.c_str(), \ - /* isRetriable */ true, \ - "{}", \ - errorMessage); } // namespace int64_t SharedArbitrator::ExtraConfig::reservedCapacity( @@ -284,7 +275,7 @@ SharedArbitrator::SharedArbitrator(const Config& config) void SharedArbitrator::shutdown() { { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); VELOX_CHECK(globalArbitrationWaiters_.empty()); if (hasShutdownLocked()) { return; @@ -436,7 +427,7 @@ void SharedArbitrator::addPool(const std::shared_ptr& pool) { auto scopedParticipant = newParticipant->lock().value(); std::vector arbitrationWaiters; { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); const uint64_t minBytesToReserve = std::min( scopedParticipant->maxCapacity(), scopedParticipant->minCapacity()); const uint64_t maxBytesToReserve = std::max( @@ -589,7 +580,7 @@ uint64_t SharedArbitrator::allocateCapacity( uint64_t requestBytes, uint64_t maxAllocateBytes, uint64_t minAllocateBytes) { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); return allocateCapacityLocked( participantId, requestBytes, maxAllocateBytes, minAllocateBytes); } @@ -745,9 +736,10 @@ bool SharedArbitrator::growCapacity(ArbitrationOperation& op) { participantConfig_.minReclaimBytes) { return false; } - - // NOTE: if global memory arbitration is not enabled, we will try to - // reclaim from the participant itself before failing this operation. + // After failing to acquire enough free capacity to fulfil this capacity + // growth request, we will try to reclaim from the participant itself before + // failing this operation. We only do this if global memory arbitration is + // not enabled. reclaim( op.participant(), op.requestBytes(), @@ -768,7 +760,7 @@ bool SharedArbitrator::startAndWaitGlobalArbitration(ArbitrationOperation& op) { ContinueFuture arbitrationWaitFuture{ContinueFuture::makeEmpty()}; uint64_t allocatedBytes{0}; { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); allocatedBytes = allocateCapacityLocked( op.participant()->id(), op.requestBytes(), @@ -838,7 +830,7 @@ void SharedArbitrator::globalArbitrationMain() { VELOX_MEM_LOG(INFO) << "Global arbitration controller started"; while (true) { { - std::unique_lock l(stateLock_); + std::unique_lock l(stateMutex_); globalArbitrationThreadCv_.wait(l, [&] { return hasShutdownLocked() || !globalArbitrationWaiters_.empty(); }); @@ -918,7 +910,7 @@ void SharedArbitrator::runGlobalArbitration() { uint64_t SharedArbitrator::getGlobalArbitrationTarget() { uint64_t targetBytes{0}; - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); for (const auto& waiter : globalArbitrationWaiters_) { targetBytes += waiter.second->op->maxGrowBytes(); } @@ -929,14 +921,6 @@ uint64_t SharedArbitrator::getGlobalArbitrationTarget() { capacity_ * globalArbitrationMemoryReclaimPct_ / 100, targetBytes); } -void SharedArbitrator::getGrowTargets( - ArbitrationOperation& op, - uint64_t& maxGrowTarget, - uint64_t& minGrowTarget) { - op.participant()->getGrowTargets( - op.requestBytes(), maxGrowTarget, minGrowTarget); -} - void SharedArbitrator::checkIfAborted(ArbitrationOperation& op) { if (op.participant()->aborted()) { VELOX_MEM_POOL_ABORTED( @@ -1141,9 +1125,7 @@ uint64_t SharedArbitrator::reclaimUsedMemoryBySpill( reclaimedBytes += reclaimResult->reclaimedBytes; } VELOX_CHECK_LE(prevReclaimedBytes, reclaimedUsedBytes_); - // NOTE: there might be concurrent local spill or spill triggered by - // external shrink. - return std::max(reclaimedBytes, reclaimedUsedBytes_ - prevReclaimedBytes); + return reclaimedBytes; } uint64_t SharedArbitrator::reclaimUsedMemoryByAbort(bool force) { @@ -1254,12 +1236,12 @@ void SharedArbitrator::freeCapacity(uint64_t bytes) { if (FOLLY_UNLIKELY(bytes == 0)) { return; } - std::vector resumes; + std::vector globalArbitrationWaitResumes; { - std::lock_guard l(stateLock_); - freeCapacityLocked(bytes, resumes); + std::lock_guard l(stateMutex_); + freeCapacityLocked(bytes, globalArbitrationWaitResumes); } - for (auto& resume : resumes) { + for (auto& resume : globalArbitrationWaitResumes) { resume.setValue(); } } @@ -1304,7 +1286,7 @@ void SharedArbitrator::resumeGlobalArbitrationWaitersLocked( void SharedArbitrator::removeGlobalArbitrationWaiter(uint64_t id) { ContinuePromise resume = ContinuePromise::makeEmpty(); { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); auto it = globalArbitrationWaiters_.find(id); if (it != globalArbitrationWaiters_.end()) { VELOX_CHECK_EQ(it->second->allocatedBytes, 0); @@ -1326,7 +1308,7 @@ void SharedArbitrator::freeReservedCapacityLocked(uint64_t& bytes) { } MemoryArbitrator::Stats SharedArbitrator::stats() const { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); return statsLocked(); } @@ -1346,7 +1328,7 @@ MemoryArbitrator::Stats SharedArbitrator::statsLocked() const { } std::string SharedArbitrator::toString() const { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); return fmt::format( "ARBITRATOR[{} CAPACITY[{}] {}]", kind_, @@ -1359,7 +1341,7 @@ SharedArbitrator::ScopedArbitration::ScopedArbitration( ArbitrationOperation* operation) : arbitrator_(arbitrator), operation_(operation), - arbitrationCtx_(operation->participant()->pool()), + arbitrationCtx_(operation->participant()->pool(), operation), startTime_(std::chrono::steady_clock::now()) { VELOX_CHECK_NOT_NULL(arbitrator_); VELOX_CHECK_NOT_NULL(operation_); diff --git a/velox/common/memory/SharedArbitrator.h b/velox/common/memory/SharedArbitrator.h index 1ed5569ca9dd8..1f39ddb3858db 100644 --- a/velox/common/memory/SharedArbitrator.h +++ b/velox/common/memory/SharedArbitrator.h @@ -304,6 +304,9 @@ class SharedArbitrator : public memory::MemoryArbitrator { const std::chrono::steady_clock::time_point startTime_; }; + // The scoped object to cover the global arbitration execution. It ensures + // the setups and teardowns of 'arbitrator' global arbitration state and + // thread_local 'arbitrationCtx' global context. class GlobalArbitrationSection { public: explicit GlobalArbitrationSection(SharedArbitrator* arbitrator); @@ -311,11 +314,13 @@ class SharedArbitrator : public memory::MemoryArbitrator { private: SharedArbitrator* const arbitrator_; + + // Default to global arbitration context. const memory::ScopedMemoryArbitrationContext arbitrationCtx_{}; }; FOLLY_ALWAYS_INLINE void checkRunning() { - std::lock_guard l(stateLock_); + std::lock_guard l(stateMutex_); VELOX_CHECK(!hasShutdownLocked(), "SharedArbitrator is not running"); } @@ -342,13 +347,6 @@ class SharedArbitrator : public memory::MemoryArbitrator { // success. bool growCapacity(ArbitrationOperation& op); - // Gets the mim/max memory capacity growth targets for 'op' once after it - // starts to run. - void getGrowTargets( - ArbitrationOperation& op, - uint64_t& maxGrowTarget, - uint64_t& minGrowTarget); - // Invoked to start execution of 'op'. It waits for the serialized execution // on the same arbitration participant and returns when 'op' is ready to run. void startArbitration(ArbitrationOperation* op); @@ -408,10 +406,10 @@ class SharedArbitrator : public memory::MemoryArbitrator { // Invoked to get the global arbitration target in bytes. uint64_t getGlobalArbitrationTarget(); - // Invoked to run global arbitration to reclaim free or used memory from the - // other queries. The global arbitration run is protected by the exclusive - // lock of 'arbitrationLock_' for serial execution mode. The function returns - // true on success, false on failure. + // Invoked to run global arbitration to reclaim free or used memory from other + // queries. The global arbitration run is protected by the exclusive lock of + // 'arbitrationLock_' for serial execution mode. The function returns true on + // success, false on failure. bool startAndWaitGlobalArbitration(ArbitrationOperation& op); // Invoked to get stats of candidate participants for arbitration. If @@ -430,21 +428,30 @@ class SharedArbitrator : public memory::MemoryArbitrator { std::vector& candidates); // Invoked to reclaim the specified used memory capacity from one or more - // participants in parallel by spilling. 'reclaimedParticipants' tracks the - // participants that have been reclaimed by spill across multiple global - // arbitration runs. 'failedParticipants' tracks the participants that have - // failed to reclaim any memory by spill. This could happen if there is some - // unknown bug or limitation in specific spillable operator implementation. - // Correspondingly, the global arbitration shall skip reclaiming from those - // participants in next arbitration round. 'allParticipantsReclaimed' - // indicates if all participants have been reclaimed by spill so far. It is - // used by gllobal arbitration to decide if need to switch to abort to reclaim - // used memory in the next arbitration round. The function returns the - // actually reclaimed used capacity in bytes. + // participants in parallel by spilling. + // + // 'reclaimedParticipants' keeps track of the participants that have been + // reclaimed by spilling. It will be taken as input to avoid reclaiming from + // these participants again. It will also be updated when additional + // participants are reclaimed. From caller's perspective, it should be kept + // and provided from across multiple global arbitration runs. + // + // 'failedParticipants' keeps track of the participants that have failed to + // reclaim any memory by spilling. This could happen if there is some unknown + // bug or limitation in specific spillable operator implementation. It will be + // taken as input to avoid reclaiming from these participants again. It will + // also be updated when additional participants fail to be reclaimed any + // memory. From caller's perspective, it should be kept and provided from + // across multiple global arbitration runs. // - // NOTE: the function sort participants based on their reclaimable used memory - // capacity, and reclaim from participants with larger reclaimable used memory - // first. + // 'allParticipantsReclaimed' returns if all participants have been + // reclaimed by spilling so far. It is used by gllobal arbitration to decide + // if need to switch to abort to reclaim used memory in the next arbitration + // round. The function returns the actually reclaimed used capacity in bytes. + // + // NOTE: the function sorts participants based on their reclaimable used + // memory capacity, and reclaims from participants with larger reclaimable + // used memory first. uint64_t reclaimUsedMemoryBySpill( uint64_t targetBytes, std::unordered_set& reclaimedParticipants, @@ -565,11 +572,6 @@ class SharedArbitrator : public memory::MemoryArbitrator { // corresponding operator's runtime stats. void incrementLocalArbitrationCount(); - size_t numParticipants() const { - std::shared_lock l(participantLock_); - return participants_.size(); - } - Stats statsLocked() const; void updateMemoryReclaimStats( @@ -605,9 +607,8 @@ class SharedArbitrator : public memory::MemoryArbitrator { std::unordered_map> participants_; - // Lock used to protect the arbitrator internal state. - mutable std::mutex stateLock_; - + // Mutex used to protect the arbitrator internal state. + mutable std::mutex stateMutex_; State state_{State::kRunning}; tsan_atomic freeReservedCapacity_{0}; @@ -627,7 +628,7 @@ class SharedArbitrator : public memory::MemoryArbitrator { std::unique_ptr globalArbitrationController_; // Signal used to wakeup 'globalArbitrationController_' to run global // arbitration on-demand. - std::condition_variable globalArbitrationThreadCv_; + std::condition_variable_any globalArbitrationThreadCv_; // Records an arbitration operation waiting for global memory arbitration. struct ArbitrationWait { @@ -640,7 +641,7 @@ class SharedArbitrator : public memory::MemoryArbitrator { }; // The map of global arbitration waiters. The key is the arbitration operation - // id which is set to id the of the corresponding arbitration participant. + // id which is set to the id of the corresponding arbitration participant. // This ensures to satisfy the arbitration request in the order of the age of // arbitration participants with old participants being served first. std::map globalArbitrationWaiters_; diff --git a/velox/common/memory/tests/ArbitrationParticipantTest.cpp b/velox/common/memory/tests/ArbitrationParticipantTest.cpp index 26330ec03b598..27748971aec7d 100644 --- a/velox/common/memory/tests/ArbitrationParticipantTest.cpp +++ b/velox/common/memory/tests/ArbitrationParticipantTest.cpp @@ -1484,81 +1484,6 @@ DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, reclaimLock) { ASSERT_EQ(scopedParticipant->stats().reclaimedBytes, 32 << 20); } -DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, waitForReclaimOrAbort) { - struct { - uint64_t waitTimeNs; - bool pendingReclaim; - uint64_t reclaimWaitMs{0}; - bool expectedTimeout; - - std::string debugString() const { - return fmt::format( - "waitTime {}, pendingReclaim {}, reclaimWait {}, expectedTimeout {}", - succinctNanos(waitTimeNs), - pendingReclaim, - succinctMillis(reclaimWaitMs), - expectedTimeout); - } - } testSettings[] = { - {0, true, 1'000, true}, - {0, false, 1'000, true}, - {1'000'000'000'000UL, true, 1'000, false}, - {1'000'000'000'000UL, true, 1'000, false}}; - - for (const auto& testData : testSettings) { - SCOPED_TRACE(testData.debugString()); - - std::atomic_bool reclaimWaitFlag{false}; - folly::EventCount reclaimWait; - SCOPED_TESTVALUE_SET( - "facebook::velox::memory::ArbitrationParticipant::reclaim", - std::function( - ([&](ArbitrationParticipant* /*unused*/) { - reclaimWaitFlag = true; - reclaimWait.notifyAll(); - std::this_thread::sleep_for( - std::chrono::milliseconds(testData.reclaimWaitMs)); // NOLINT - }))); - - SCOPED_TESTVALUE_SET( - "facebook::velox::memory::ArbitrationParticipant::abortLocked", - std::function( - ([&](ArbitrationParticipant* /*unused*/) { - reclaimWaitFlag = true; - reclaimWait.notifyAll(); - std::this_thread::sleep_for( - std::chrono::milliseconds(testData.reclaimWaitMs)); // NOLINT - }))); - - auto task = createTask(kMemoryCapacity); - const auto config = arbitrationConfig(); - auto participant = - ArbitrationParticipant::create(10, task->pool(), &config); - task->allocate(MB); - auto scopedParticipant = participant->lock().value(); - - std::thread reclaimThread([&]() { - if (testData.pendingReclaim) { - memory::MemoryReclaimer::Stats stats; - ASSERT_EQ( - scopedParticipant->reclaim(MB, 1'000'000'000'000UL, stats), MB); - } else { - const std::string abortReason = "test abort"; - try { - VELOX_FAIL(abortReason); - } catch (const VeloxRuntimeError& e) { - ASSERT_EQ(scopedParticipant->abort(std::current_exception()), MB); - } - } - }); - reclaimWait.await([&]() { return reclaimWaitFlag.load(); }); - ASSERT_EQ( - scopedParticipant->waitForReclaimOrAbort(testData.waitTimeNs), - !testData.expectedTimeout); - reclaimThread.join(); - } -} - // This test verifies the aborted returns true until the participant has been // aborted. DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, abortedCheck) { @@ -1950,5 +1875,86 @@ TEST_F(ArbitrationParticipantTest, arbitrationOperationState) { static_cast(10)), "unknown state: 10"); } + +TEST_F(ArbitrationParticipantTest, arbitrationOperationTimedLock) { + auto participantPool = manager_->addRootPool("arbitrationOperationTimedLock"); + auto config = ArbitrationParticipant::Config(0, 1024, 0, 0, 0, 0, 128, 512); + auto participant = ArbitrationParticipant::create( + folly::Random::rand64(), participantPool, &config); + + auto createLockHolderThread = [](std::timed_mutex& mutex, + uint64_t lockHoldTimeNs, + folly::EventCount& lockWait, + std::atomic_bool& lockWaitFlag) { + return std::thread([&, sleepNs = lockHoldTimeNs]() { + std::lock_guard l(mutex); + lockWaitFlag = false; + lockWait.notifyAll(); + std::this_thread::sleep_for(std::chrono::nanoseconds(sleepNs)); + }); + }; + + struct TestData { + std::string type; + uint64_t lockHoldTimeNs; + uint64_t opTimeoutNs; + }; + + std::timed_mutex mutex; + std::vector testDataVec{ + {"local", 1'000'000'000UL, 2'000'000'000UL}, + {"local", 2'000'000'000UL, 1'000'000'000UL}, + {"global", 1'000'000'000UL, 2'000'000'000UL}, + {"global", 2'000'000'000UL, 1'000'000'000UL}, + {"none", 1'000'000'000UL, 2'000'000'000UL}}; + + for (auto& testData : testDataVec) { + ScopedArbitrationParticipant scopedArbitrationParticipant( + participant, participantPool); + ArbitrationOperation operation( + std::move(scopedArbitrationParticipant), 1024, testData.opTimeoutNs); + if (testData.type == "local") { + MemoryArbitrationContext ctx(participantPool.get(), &operation); + ScopedMemoryArbitrationContext scopedCtx(&ctx); + + folly::EventCount lockWait; + std::atomic_bool lockWaitFlag{true}; + auto lockHolder = createLockHolderThread( + mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag); + std::unique_ptr timedLock{nullptr}; + lockWait.await([&]() { return !lockWaitFlag.load(); }); + if (testData.lockHoldTimeNs < testData.opTimeoutNs) { + timedLock = std::make_unique(mutex); + ASSERT_FALSE(mutex.try_lock()); + } else { + VELOX_ASSERT_THROW( + std::make_unique(mutex), + "Memory arbitration lock timed out"); + } + lockHolder.join(); + } else if (testData.type == "global") { + MemoryArbitrationContext ctx; + ScopedMemoryArbitrationContext scopedCtx(&ctx); + + folly::EventCount lockWait; + std::atomic_bool lockWaitFlag{true}; + auto lockHolder = createLockHolderThread( + mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag); + lockWait.await([&]() { return !lockWaitFlag.load(); }); + ArbitrationOperationTimedLock timedLock(mutex); + ASSERT_FALSE(mutex.try_lock()); + lockHolder.join(); + } else { + folly::EventCount lockWait; + std::atomic_bool lockWaitFlag{true}; + auto lockHolder = createLockHolderThread( + mutex, testData.lockHoldTimeNs, lockWait, lockWaitFlag); + lockWait.await([&]() { return !lockWaitFlag.load(); }); + ArbitrationOperationTimedLock timedLock(mutex); + ASSERT_FALSE(mutex.try_lock()); + lockHolder.join(); + } + } +} } // namespace } // namespace facebook::velox::memory diff --git a/velox/common/memory/tests/MemoryArbitratorTest.cpp b/velox/common/memory/tests/MemoryArbitratorTest.cpp index dbf1907e6912f..881cff43b843f 100644 --- a/velox/common/memory/tests/MemoryArbitratorTest.cpp +++ b/velox/common/memory/tests/MemoryArbitratorTest.cpp @@ -23,6 +23,7 @@ #include "velox/common/memory/Memory.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" using namespace ::testing; @@ -989,13 +990,19 @@ TEST_F(MemoryReclaimerTest, arbitrationContext) { ASSERT_FALSE(isSpillMemoryPool(leafChild2.get())); ASSERT_TRUE(memoryArbitrationContext() == nullptr); { - ScopedMemoryArbitrationContext arbitrationContext(leafChild1.get()); + auto arbitrationStructs = + test::ArbitrationTestStructs::createArbitrationTestStructs(leafChild1); + ScopedMemoryArbitrationContext arbitrationContext( + leafChild1.get(), arbitrationStructs.operation.get()); ASSERT_TRUE(memoryArbitrationContext() != nullptr); ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild1->name()); } ASSERT_TRUE(memoryArbitrationContext() == nullptr); { - ScopedMemoryArbitrationContext arbitrationContext(leafChild2.get()); + auto arbitrationStructs = + test::ArbitrationTestStructs::createArbitrationTestStructs(leafChild2); + ScopedMemoryArbitrationContext arbitrationContext( + leafChild2.get(), arbitrationStructs.operation.get()); ASSERT_TRUE(memoryArbitrationContext() != nullptr); ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild2->name()); } @@ -1003,13 +1010,21 @@ TEST_F(MemoryReclaimerTest, arbitrationContext) { std::thread nonAbitrationThread([&]() { ASSERT_TRUE(memoryArbitrationContext() == nullptr); { - ScopedMemoryArbitrationContext arbitrationContext(leafChild1.get()); + auto arbitrationStructs = + test::ArbitrationTestStructs::createArbitrationTestStructs( + leafChild1); + ScopedMemoryArbitrationContext arbitrationContext( + leafChild1.get(), arbitrationStructs.operation.get()); ASSERT_TRUE(memoryArbitrationContext() != nullptr); ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild1->name()); } ASSERT_TRUE(memoryArbitrationContext() == nullptr); { - ScopedMemoryArbitrationContext arbitrationContext(leafChild2.get()); + auto arbitrationStructs = + test::ArbitrationTestStructs::createArbitrationTestStructs( + leafChild2); + ScopedMemoryArbitrationContext arbitrationContext( + leafChild2.get(), arbitrationStructs.operation.get()); ASSERT_TRUE(memoryArbitrationContext() != nullptr); ASSERT_EQ(memoryArbitrationContext()->requestorName, leafChild2->name()); } diff --git a/velox/common/memory/tests/MemoryPoolTest.cpp b/velox/common/memory/tests/MemoryPoolTest.cpp index 1c5667db0e203..fa0d94f92239a 100644 --- a/velox/common/memory/tests/MemoryPoolTest.cpp +++ b/velox/common/memory/tests/MemoryPoolTest.cpp @@ -26,6 +26,7 @@ #include "velox/common/memory/MemoryPool.h" #include "velox/common/memory/MmapAllocator.h" #include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" DECLARE_bool(velox_memory_leak_check_enabled); @@ -3887,7 +3888,10 @@ TEST_P(MemoryPoolTest, overuseUnderArbitration) { ASSERT_FALSE(child->maybeReserve(2 * kMaxSize)); ASSERT_EQ(child->usedBytes(), 0); ASSERT_EQ(child->reservedBytes(), 0); - ScopedMemoryArbitrationContext scopedMemoryArbitration(child.get()); + auto arbitrationTestStructs = + test::ArbitrationTestStructs::createArbitrationTestStructs(root); + ScopedMemoryArbitrationContext scopedMemoryArbitration( + root.get(), arbitrationTestStructs.operation.get()); ASSERT_TRUE(underMemoryArbitration()); ASSERT_TRUE(child->maybeReserve(2 * kMaxSize)); ASSERT_EQ(child->usedBytes(), 0); diff --git a/velox/common/memory/tests/MockSharedArbitratorTest.cpp b/velox/common/memory/tests/MockSharedArbitratorTest.cpp index 1a36784bc887a..2acdb3f6cac2d 100644 --- a/velox/common/memory/tests/MockSharedArbitratorTest.cpp +++ b/velox/common/memory/tests/MockSharedArbitratorTest.cpp @@ -42,7 +42,6 @@ using namespace facebook::velox::exec; using namespace facebook::velox::exec::test; namespace facebook::velox::memory { -namespace { // Class to write runtime stats in the tests to the stats container. class TestRuntimeStatWriter : public BaseRuntimeStatWriter { public: @@ -80,7 +79,7 @@ using ReclaimInjectionCallback = std::function; using ArbitrationInjectionCallback = std::function; -struct Allocation { +struct AllocatedBuffer { void* buffer{nullptr}; size_t size{0}; }; @@ -298,7 +297,7 @@ class MockMemoryOperator { } void free() { - Allocation allocation; + AllocatedBuffer allocation; { std::lock_guard l(mu_); if (allocations_.empty()) { @@ -327,7 +326,7 @@ class MockMemoryOperator { uint64_t reclaim(MemoryPool* pool, uint64_t targetBytes) { VELOX_CHECK_GT(targetBytes, 0); uint64_t bytesReclaimed{0}; - std::vector allocationsToFree; + std::vector allocationsToFree; { std::lock_guard l(mu_); VELOX_CHECK_NOT_NULL(pool_); @@ -2831,6 +2830,60 @@ DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationTimeout) { ASSERT_EQ(task->capacity(), 0); } +DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, reclaimLockTimeout) { + const uint64_t memoryCapacity = 256 * MB; + const uint64_t arbitrationTimeoutMs = 1'000; + setupMemory( + memoryCapacity, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1.0, + nullptr, + false, + arbitrationTimeoutMs); + std::shared_ptr task = addTask(memoryCapacity); + ASSERT_EQ(task->capacity(), 0); + auto* op = task->addMemoryOp(true); + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::abort", + std::function( + ([&](const ArbitrationParticipant* /*unused*/) { + std::this_thread::sleep_for( + std::chrono::milliseconds(2 * arbitrationTimeoutMs)); // NOLINT + }))); + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::reclaim", + std::function( + ([&](const ArbitrationParticipant* /*unused*/) { + // Timeout shall be enforced at lock level. We don't expect code to + // execute pass the lock in reclaim method. + FAIL(); + }))); + + auto abortThread = std::thread( + [&]() { arbitrator_->shrinkCapacity(memoryCapacity, false, true); }); + try { + op->allocate(memoryCapacity / 2); + } catch (const VeloxException& ex) { + ASSERT_EQ(ex.errorCode(), error_code::kMemArbitrationTimeout); + ASSERT_THAT( + ex.what(), + testing::HasSubstr("Memory arbitration timed out on memory pool")); + } + + abortThread.join(); +} + DEBUG_ONLY_TEST_F(MockSharedArbitrationTest, localArbitrationQueueTimeout) { uint64_t memoryCapacity = 256 * MB; setupMemory( @@ -4173,5 +4226,4 @@ TEST_F(MockSharedArbitrationTest, concurrentArbitrationWithTransientRoots) { } controlThread.join(); } -} // namespace } // namespace facebook::velox::memory diff --git a/velox/common/memory/tests/SharedArbitratorTestUtil.h b/velox/common/memory/tests/SharedArbitratorTestUtil.h index 536a36f33ea74..0ba88c6a8cc4f 100644 --- a/velox/common/memory/tests/SharedArbitratorTestUtil.h +++ b/velox/common/memory/tests/SharedArbitratorTestUtil.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once #include "velox/common/memory/ArbitrationParticipant.h" #include "velox/common/memory/SharedArbitrator.h" @@ -29,7 +30,7 @@ class SharedArbitratorTestHelper { } size_t numParticipants() { - std::lock_guard l(arbitrator_->stateLock_); + std::lock_guard l(arbitrator_->stateMutex_); return arbitrator_->participants_.size(); } @@ -38,12 +39,12 @@ class SharedArbitratorTestHelper { } size_t numGlobalArbitrationWaiters() const { - std::lock_guard l(arbitrator_->stateLock_); + std::lock_guard l(arbitrator_->stateMutex_); return arbitrator_->globalArbitrationWaiters_.size(); } bool globalArbitrationRunning() const { - std::lock_guard l(arbitrator_->stateLock_); + std::lock_guard l(arbitrator_->stateMutex_); return arbitrator_->globalArbitrationRunning_; } @@ -70,7 +71,7 @@ class SharedArbitratorTestHelper { } bool hasShutdown() const { - std::lock_guard l(arbitrator_->stateLock_); + std::lock_guard l(arbitrator_->stateMutex_); return arbitrator_->hasShutdownLocked(); } @@ -107,4 +108,42 @@ class ArbitrationParticipantTestHelper { private: ArbitrationParticipant* const participant_; }; + +struct ArbitrationTestStructs { + ArbitrationParticipant::Config config; + std::shared_ptr participant{nullptr}; + std::shared_ptr operation{nullptr}; + + static ArbitrationTestStructs createArbitrationTestStructs( + const std::shared_ptr& pool, + uint64_t initCapacity = 1024, + uint64_t minCapacity = 128, + uint64_t fastExponentialGrowthCapacityLimit = 0, + double slowCapacityGrowRatio = 0, + uint64_t minFreeCapacity = 0, + double minFreeCapacityRatio = 0, + uint64_t minReclaimBytes = 128, + uint64_t abortCapacityLimit = 512, + uint64_t requestBytes = 128, + uint64_t maxArbitrationTimeNs = 1'000'000'000'000UL /* 1'000s */) { + ArbitrationTestStructs ret{ + .config = ArbitrationParticipant::Config( + initCapacity, + minCapacity, + fastExponentialGrowthCapacityLimit, + slowCapacityGrowRatio, + minFreeCapacity, + minFreeCapacityRatio, + minReclaimBytes, + abortCapacityLimit)}; + ret.participant = ArbitrationParticipant::create( + folly::Random::rand64(), pool, &ret.config); + ret.operation = std::make_shared( + ScopedArbitrationParticipant(ret.participant, pool), + requestBytes, + maxArbitrationTimeNs); + return ret; + } +}; + } // namespace facebook::velox::memory::test diff --git a/velox/dwio/dwrf/test/E2EWriterTest.cpp b/velox/dwio/dwrf/test/E2EWriterTest.cpp index 22d1c624c9464..37de542b675d3 100644 --- a/velox/dwio/dwrf/test/E2EWriterTest.cpp +++ b/velox/dwio/dwrf/test/E2EWriterTest.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/base/SpillConfig.h" #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/Options.h" #include "velox/dwio/common/Statistics.h" @@ -1734,7 +1735,11 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimOnWrite) { const auto oldReservedBytes = writerPool->reservedBytes(); const auto oldUsedBytes = writerPool->usedBytes(); { - memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + writerPool); + memory::ScopedMemoryArbitrationContext arbitrationCtx( + writerPool.get(), arbitrationStructs.operation.get()); writerPool->reclaim(1L << 30, 0, stats); } ASSERT_EQ(stats.numNonReclaimableAttempts, 0); @@ -1773,7 +1778,11 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimOnWrite) { writer->testingNonReclaimableSection() = false; stats.numNonReclaimableAttempts = 0; { - memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + writerPool); + memory::ScopedMemoryArbitrationContext arbitrationCtx( + writerPool.get(), arbitrationStructs.operation.get()); const auto reclaimedBytes = writerPool->reclaim(1L << 30, 0, stats); ASSERT_GT(reclaimedBytes, 0); } @@ -2115,7 +2124,11 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimThreshold) { *writerPool, reclaimableBytes)); ASSERT_GT(reclaimableBytes, 0); { - memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + writerPool); + memory::ScopedMemoryArbitrationContext arbitrationCtx( + writerPool.get(), arbitrationStructs.operation.get()); ASSERT_GT(writerPool->reclaim(1L << 30, 0, stats), 0); } ASSERT_GT(stats.reclaimExecTimeUs, 0); @@ -2125,7 +2138,11 @@ DEBUG_ONLY_TEST_F(E2EWriterTest, memoryReclaimThreshold) { *writerPool, reclaimableBytes)); ASSERT_EQ(reclaimableBytes, 0); { - memory::ScopedMemoryArbitrationContext arbitrationCtx(writerPool.get()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + writerPool); + memory::ScopedMemoryArbitrationContext arbitrationCtx( + writerPool.get(), arbitrationStructs.operation.get()); ASSERT_EQ(writerPool->reclaim(1L << 30, 0, stats), 0); } ASSERT_EQ(stats.numNonReclaimableAttempts, 0); diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp index c48e01fb518b8..1138ef7406d74 100644 --- a/velox/exec/tests/AggregationTest.cpp +++ b/velox/exec/tests/AggregationTest.cpp @@ -22,6 +22,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Aggregate.h" @@ -2132,7 +2133,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringInputProcessing) { if (testData.expectedReclaimable) { { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), 0, @@ -2145,7 +2150,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringInputProcessing) { ASSERT_EQ(op->pool()->usedBytes(), 0); } else { { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); VELOX_ASSERT_THROW( op->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), @@ -2261,7 +2270,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringReserve) { const auto usedMemory = op->pool()->usedBytes(); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), 0, @@ -2506,7 +2519,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringOutputProcessing) { if (enableSpilling) { ASSERT_GT(reclaimableBytes, 0); const auto usedMemory = op->pool()->usedBytes(); - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), 0, @@ -2518,7 +2535,11 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringOutputProcessing) { reclaimerStats_.reset(); } else { ASSERT_EQ(reclaimableBytes, 0); - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); VELOX_ASSERT_THROW( op->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), @@ -3141,7 +3162,10 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimEmptyOutput) { { MemoryReclaimer::Stats stats; SuspendedSection suspendedSection(driver); - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = memory::test::ArbitrationTestStructs:: + createArbitrationTestStructs(op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); task->pool()->reclaim(kMaxBytes, 0, stats); ASSERT_EQ(stats.numNonReclaimableAttempts, 0); ASSERT_GT(stats.reclaimExecTimeUs, 0); diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 98c1c20796b41..cabe33d1fa878 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -20,6 +20,7 @@ #include "folly/experimental/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/HashBuild.h" @@ -5783,7 +5784,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringInputProcessing) { if (testData.expectedReclaimable) { { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), 0, @@ -5923,7 +5928,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringReserve) { ASSERT_GT(reclaimableBytes, 0); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), 0, @@ -6178,7 +6187,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringOutputProcessing) { ASSERT_GT(reclaimableBytes, 0); const auto usedMemoryBytes = op->pool()->usedBytes(); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), 0, @@ -6259,7 +6272,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { } auto* driver = op->testingOperatorCtx()->driver(); auto task = driver->task(); - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); SuspendedSection suspendedSection(driver); auto taskPauseWait = task->requestPause(); taskPauseWait.wait(); @@ -6326,7 +6343,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) { const auto usedMemoryBytes = op->pool()->usedBytes(); reclaimerStats_.reset(); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(), 0, diff --git a/velox/exec/tests/MemoryReclaimerTest.cpp b/velox/exec/tests/MemoryReclaimerTest.cpp index aaa54fb8027c3..9880f3102f647 100644 --- a/velox/exec/tests/MemoryReclaimerTest.cpp +++ b/velox/exec/tests/MemoryReclaimerTest.cpp @@ -17,6 +17,7 @@ #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/memory/MemoryPool.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -264,7 +265,11 @@ TEST_F(MemoryReclaimerTest, parallelMemoryReclaimer) { static_cast(leafPools.back()->reclaimer())); } - ScopedMemoryArbitrationContext context(rootPool.get()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + rootPool); + memory::ScopedMemoryArbitrationContext context( + rootPool.get(), arbitrationStructs.operation.get()); memory::MemoryReclaimer::Stats stats; rootPool->reclaim(testData.bytesToReclaim, 0, stats); for (int i = 0; i < memoryReclaimers.size(); ++i) { diff --git a/velox/exec/tests/OrderByTest.cpp b/velox/exec/tests/OrderByTest.cpp index 0d7fa4475a2bc..a193f9df51bca 100644 --- a/velox/exec/tests/OrderByTest.cpp +++ b/velox/exec/tests/OrderByTest.cpp @@ -19,6 +19,7 @@ #include "folly/experimental/EventCount.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/common/file/FileSystems.h" +#include "velox/common/memory/tests/SharedArbitratorTestUtil.h" #include "velox/common/testutil/TestValue.h" #include "velox/core/QueryConfig.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" @@ -651,7 +652,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringInputProcessing) { if (testData.expectedReclaimable) { { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), 0, @@ -777,7 +782,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringReserve) { ASSERT_GT(reclaimableBytes, 0); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), 0, @@ -1028,7 +1037,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringOutputProcessing) { ASSERT_GT(reclaimableBytes, 0); reclaimerStats_.reset(); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); op->pool()->reclaim(reclaimableBytes, 0, reclaimerStats_); } ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); @@ -1037,7 +1050,11 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringOutputProcessing) { } else { ASSERT_EQ(reclaimableBytes, 0); { - memory::ScopedMemoryArbitrationContext ctx(op->pool()); + auto arbitrationStructs = + memory::test::ArbitrationTestStructs::createArbitrationTestStructs( + op->pool()->shared_from_this()); + memory::ScopedMemoryArbitrationContext ctx( + op->pool(), arbitrationStructs.operation.get()); VELOX_ASSERT_THROW( op->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_),