diff --git a/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp b/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp index f04ed4e65245..804a26ebed13 100644 --- a/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp +++ b/velox/exec/fuzzer/MemoryArbitrationFuzzer.cpp @@ -15,7 +15,6 @@ */ #include "velox/exec/fuzzer/MemoryArbitrationFuzzer.h" - #include #include "velox/common/file/FileSystems.h" @@ -71,11 +70,6 @@ DEFINE_int32( DEFINE_int64(arbitrator_capacity, 256L << 20, "Arbitrator capacity in bytes."); -DEFINE_int32( - abort_injection_pct, - 5, - "The percentage chance of triggering task abort"); - DEFINE_int32( global_arbitration_pct, 5, @@ -94,6 +88,12 @@ DEFINE_int32( "filesystem. This is only applicable when 'spill_faulty_fs_ratio' is " "larger than 0"); +DEFINE_int32( + task_abort_interval_ms, + 1000, + "After each specified number of milliseconds, abort a random task." + "If given 0, no task will be aborted."); + using namespace facebook::velox::tests::utils; namespace facebook::velox::exec::test { @@ -697,12 +697,12 @@ MemoryArbitrationFuzzer::orderByPlans(const std::string& tableDir) { } struct ThreadLocalStats { - uint64_t taskAbortCount{0}; uint64_t spillFsFaultCount{0}; }; // Stats that keeps track of per thread execution status in verify() thread_local ThreadLocalStats threadLocalStats; +std::atomic_uint32_t taskAbortRequestCount{0}; std::shared_ptr MemoryArbitrationFuzzer::maybeGenerateFaultySpillDirectory() { @@ -745,12 +745,6 @@ void MemoryArbitrationFuzzer::verify() { auto spillDirectory = maybeGenerateFaultySpillDirectory(); const auto tableScanDir = exec::test::TempDirectoryPath::create(false); - // Set a percentage chance for the task to be externally aborted. - TestScopedAbortInjection scopedAbortInjection( - FLAGS_abort_injection_pct, - std::numeric_limits::max(), - [](Task* /* unused */) { ++threadLocalStats.taskAbortCount; }); - std::vector plans; for (const auto& plan : hashJoinPlans(tableScanDir->getPath())) { plans.push_back(plan); @@ -782,8 +776,8 @@ void MemoryArbitrationFuzzer::verify() { queryThreads.emplace_back([&, spillDirectory, i, seed]() { FuzzerGenerator rng(seed); while (!stop) { - const auto prevAbortCount = threadLocalStats.taskAbortCount; const auto prevSpillFsFaultCount = threadLocalStats.spillFsFaultCount; + const auto prevTaskAbortRequestCount = taskAbortRequestCount.load(); try { const auto queryCtx = newQueryCtx( memory::memoryManager(), @@ -816,16 +810,16 @@ void MemoryArbitrationFuzzer::verify() { } else if (e.errorCode() == error_code::kMemAborted.c_str()) { ++lockedStats->abortCount; } else if (e.errorCode() == error_code::kInvalidState.c_str()) { - const auto injectedAbort = - threadLocalStats.taskAbortCount > prevAbortCount; const auto injectedSpillFsFault = threadLocalStats.spillFsFaultCount > prevSpillFsFaultCount; - VELOX_CHECK(injectedAbort || injectedSpillFsFault); - if (injectedAbort && !injectedSpillFsFault) { + const auto injectedTaskAbortRequest = + taskAbortRequestCount > prevTaskAbortRequestCount; + VELOX_CHECK(injectedSpillFsFault || injectedTaskAbortRequest); + if (injectedTaskAbortRequest && !injectedSpillFsFault) { VELOX_CHECK( e.message().find("Aborted for external error") != std::string::npos); - } else if (!injectedAbort && injectedSpillFsFault) { + } else if (!injectedTaskAbortRequest && injectedSpillFsFault) { VELOX_CHECK( e.message().find("Fault file injection on") != std::string::npos); @@ -855,6 +849,28 @@ void MemoryArbitrationFuzzer::verify() { } }); + // Create a thread that randomly abort one worker thread + // every task_abort_interval_ms milliseconds. + std::thread abortControlThread([&]() { + if (FLAGS_task_abort_interval_ms == 0) { + return; + } + while (!stop) { + try { + std::this_thread::sleep_for( + std::chrono::milliseconds(FLAGS_task_abort_interval_ms)); + auto tasksList = Task::getRunningTasks(); + auto index = getRandomIndex(rng_, tasksList.size() - 1); + ++taskAbortRequestCount; + tasksList[index]->requestAbort(); + } catch (const VeloxException& e) { + LOG(ERROR) << "Unexpected exception in abortControlScheduler:\n" + << e.what(); + std::rethrow_exception(std::current_exception()); + } + } + }); + std::this_thread::sleep_for( std::chrono::seconds(FLAGS_iteration_duration_sec)); stop = true; @@ -863,6 +879,7 @@ void MemoryArbitrationFuzzer::verify() { queryThread.join(); } globalShrinkThread.join(); + abortControlThread.join(); } void MemoryArbitrationFuzzer::go() {