diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index 1d664d1a0ac1..cda854de711d 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -457,4 +457,18 @@ MemoryArbitrationContext* memoryArbitrationContext() { bool underMemoryArbitration() { return memoryArbitrationContext() != nullptr; } + +void testingRunArbitration(uint64_t targetBytes, MemoryManager* manager) { + if (manager == nullptr) { + manager = memory::memoryManager(); + } + manager->shrinkPools(targetBytes); +} + +void testingRunArbitration(MemoryPool* pool, uint64_t targetBytes) { + pool->enterArbitration(); + static_cast(pool)->testingManager()->shrinkPools( + targetBytes); + pool->leaveArbitration(); +} } // namespace facebook::velox::memory diff --git a/velox/common/memory/MemoryArbitrator.h b/velox/common/memory/MemoryArbitrator.h index 20371f6bdc26..0f675a6dec37 100644 --- a/velox/common/memory/MemoryArbitrator.h +++ b/velox/common/memory/MemoryArbitrator.h @@ -383,4 +383,18 @@ MemoryArbitrationContext* memoryArbitrationContext(); /// Returns true if the running thread is under memory arbitration or not. bool underMemoryArbitration(); + +/// The function triggers memory arbitration by shrinking memory pools from +/// 'manager' by invoking shrinkPools API. If 'manager' is not set, then it +/// shrinks from the process wide memory manager. If 'targetBytes' is zero, then +/// reclaims all the memory from 'manager' if possible. +class MemoryManager; +void testingRunArbitration( + uint64_t targetBytes = 0, + MemoryManager* manager = nullptr); + +/// The function triggers memory arbitration by shrinking memory pools from +/// 'manager' of 'pool' by invoking its shrinkPools API. If 'targetBytes' is +/// zero, then reclaims all the memory from 'manager' if possible. +void testingRunArbitration(MemoryPool* pool, uint64_t targetBytes = 0); } // namespace facebook::velox::memory diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index 841abcc5d9a8..e9a55ad15126 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -679,6 +679,10 @@ class MemoryPoolImpl : public MemoryPool { void testingSetCapacity(int64_t bytes); + MemoryManager* testingManager() const { + return manager_; + } + MemoryAllocator* testingAllocator() const { return allocator_; } diff --git a/velox/core/QueryCtx.h b/velox/core/QueryCtx.h index bab975301c4e..d27370bcb9ae 100644 --- a/velox/core/QueryCtx.h +++ b/velox/core/QueryCtx.h @@ -142,8 +142,10 @@ class QueryCtx { void initPool(const std::string& queryId) { if (pool_ == nullptr) { - pool_ = memory::deprecatedDefaultMemoryManager().addRootPool( - QueryCtx::generatePoolName(queryId)); + pool_ = memory::memoryManager()->addRootPool( + QueryCtx::generatePoolName(queryId), + memory::kMaxMemory, + memory::MemoryReclaimer::create()); } } diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 2e3965e17742..3730ce300373 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -452,6 +452,7 @@ bool HashBuild::ensureInputFits(RowVectorPtr& input) { bool HashBuild::reserveMemory(const RowVectorPtr& input) { VELOX_CHECK(spillEnabled()); + Operator::ReclaimableSectionGuard guard(this); numSpillRows_ = 0; numSpillBytes_ = 0; @@ -468,9 +469,10 @@ bool HashBuild::reserveMemory(const RowVectorPtr& input) { if (numRows != 0) { // Test-only spill path. if (testingTriggerSpill()) { - numSpillRows_ = std::max(1, numRows / 10); - numSpillBytes_ = numSpillRows_ * outOfLineBytesPerRow; - return false; + memory::testingRunArbitration(pool()); + // NOTE: the memory arbitration should have triggered spilling on this + // hash build operator so we return true to indicate have enough memory. + return true; } // We check usage from the parent pool to take peers' allocations into @@ -522,11 +524,8 @@ bool HashBuild::reserveMemory(const RowVectorPtr& input) { incrementBytes * 2, currentUsage * spillConfig_->spillableReservationGrowthPct / 100); - { - Operator::ReclaimableSectionGuard guard(this); - if (pool()->maybeReserve(targetIncrementBytes)) { - return true; - } + if (pool()->maybeReserve(targetIncrementBytes)) { + return true; } LOG(WARNING) << "Failed to reserve " << succinctBytes(targetIncrementBytes) diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 43ea05ecfdd1..794528d7b447 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -24,7 +24,6 @@ #include "velox/exec/HashBuild.h" #include "velox/exec/HashJoinBridge.h" #include "velox/exec/PlanNodeStats.h" -#include "velox/exec/TableScan.h" #include "velox/exec/tests/utils/ArbitratorTestUtil.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/Cursor.h" @@ -734,6 +733,11 @@ class HashJoinBuilder { class HashJoinTest : public HiveConnectorTestBase { protected: + static void SetUpTestCase() { + FLAGS_velox_testing_enable_arbitration = true; + HiveConnectorTestBase::SetUpTestCase(); + } + HashJoinTest() : HashJoinTest(TestParam(1)) {} explicit HashJoinTest(const TestParam& param) diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.cpp b/velox/exec/tests/utils/ArbitratorTestUtil.cpp index e72b0743c778..58c5d0af4b25 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.cpp +++ b/velox/exec/tests/utils/ArbitratorTestUtil.cpp @@ -348,21 +348,4 @@ QueryTestResult runWriteTask( } return result; } - -void testingRunArbitration( - memory::MemoryPool* pool, - uint64_t targetBytes, - memory::MemoryManager* manager) { - if (manager == nullptr) { - manager = memory::memoryManager(); - } - if (pool != nullptr) { - pool->enterArbitration(); - manager->shrinkPools(targetBytes); - pool->leaveArbitration(); - } else { - manager->shrinkPools(targetBytes); - } -} - } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/ArbitratorTestUtil.h b/velox/exec/tests/utils/ArbitratorTestUtil.h index ddbddd232dc4..3b6610fd12f5 100644 --- a/velox/exec/tests/utils/ArbitratorTestUtil.h +++ b/velox/exec/tests/utils/ArbitratorTestUtil.h @@ -179,15 +179,4 @@ QueryTestResult runWriteTask( const std::string& kHiveConnectorId, bool enableSpilling, const RowVectorPtr& expectedResult = nullptr); - -/// The function triggers memory arbitration by shrinking memory pools from -/// 'manager' by invoking shrinkPools API. If 'manager' is not set, then it -/// shrinks from the process wide memory manager. If 'pool' is provided, the -/// function puts 'pool' in arbitration state before the arbitration to ease -/// test use. If 'targetBytes' is zero, then reclaims all the memory from -/// 'manager' if possible. -void testingRunArbitration( - memory::MemoryPool* pool = nullptr, - uint64_t targetBytes = 0, - memory::MemoryManager* manager = nullptr); } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/OperatorTestBase.cpp b/velox/exec/tests/utils/OperatorTestBase.cpp index 4871d58ca325..ac19629621de 100644 --- a/velox/exec/tests/utils/OperatorTestBase.cpp +++ b/velox/exec/tests/utils/OperatorTestBase.cpp @@ -34,6 +34,10 @@ DECLARE_bool(velox_memory_leak_check_enabled); DECLARE_bool(velox_enable_memory_usage_track_in_default_memory_pool); +DEFINE_bool( + velox_testing_enable_arbitration, + false, + "Enable to turn on arbitration for tests by default"); using namespace facebook::velox::common::testutil; @@ -58,6 +62,12 @@ void OperatorTestBase::SetUpTestCase() { exec::SharedArbitrator::registerFactory(); memory::MemoryManagerOptions options; options.allocatorCapacity = 8L << 30; + if (FLAGS_velox_testing_enable_arbitration) { + options.arbitratorCapacity = 6L << 30; + options.arbitratorKind = "SHARED"; + options.checkUsageLeak = true; + options.arbitrationStateCheckCb = memoryArbitrationStateCheck; + } memory::MemoryManager::testingSetInstance(options); asyncDataCache_ = cache::AsyncDataCache::create(memoryManager()->allocator()); cache::AsyncDataCache::setInstance(asyncDataCache_.get()); diff --git a/velox/exec/tests/utils/OperatorTestBase.h b/velox/exec/tests/utils/OperatorTestBase.h index 54eb0d6b97eb..a10b316c173f 100644 --- a/velox/exec/tests/utils/OperatorTestBase.h +++ b/velox/exec/tests/utils/OperatorTestBase.h @@ -28,6 +28,8 @@ #include "velox/vector/tests/utils/VectorMaker.h" #include "velox/vector/tests/utils/VectorTestBase.h" +DECLARE_bool(velox_testing_enable_arbitration); + namespace facebook::velox::exec::test { class OperatorTestBase : public testing::Test, public velox::test::VectorTestBase { diff --git a/velox/expression/tests/FuzzerRunner.cpp b/velox/expression/tests/FuzzerRunner.cpp index aef6b786f8da..76d9aff62e46 100644 --- a/velox/expression/tests/FuzzerRunner.cpp +++ b/velox/expression/tests/FuzzerRunner.cpp @@ -210,6 +210,7 @@ int FuzzerRunner::run( void FuzzerRunner::runFromGtest( size_t seed, const std::unordered_set& skipFunctions) { + memory::MemoryManager::testingSetInstance({}); auto signatures = facebook::velox::getFunctionSignatures(); ExpressionFuzzerVerifier( signatures, seed, getExpressionFuzzerVerifierOptions(skipFunctions))