Skip to content

Commit

Permalink
Fix race between task resume and task terminate
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoxmeng committed Dec 7, 2023
1 parent 975ca3a commit 4c5f917
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 12 deletions.
6 changes: 3 additions & 3 deletions velox/exec/Task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,9 +824,9 @@ void Task::resume(std::shared_ptr<Task> self) {
// Setting pause requested must be atomic with the resuming so that
// suspended sections do not go back on thread during resume.
self->pauseRequested_ = false;
if (self->exception_ == nullptr) {
if (self->isRunningLocked()) {
for (auto& driver : self->drivers_) {
if (driver) {
if (driver != nullptr) {
if (driver->state().isSuspended) {
// The Driver will come on thread in its own time as long as
// the cancel flag is reset. This check needs to be inside 'mutex_'.
Expand Down Expand Up @@ -2285,7 +2285,7 @@ StopReason Task::enter(ThreadState& state, uint64_t nowMicros) {
if (state.isOnThread()) {
return StopReason::kAlreadyOnThread;
}
auto reason = shouldStopLocked();
const auto reason = shouldStopLocked();
if (reason == StopReason::kTerminate) {
state.isTerminated = true;
}
Expand Down
5 changes: 5 additions & 0 deletions velox/exec/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ class Task : public std::enable_shared_from_this<Task> {
/// Invoked to run provided 'callback' on each alive driver of the task.
void testingVisitDrivers(const std::function<void(Driver*)>& callback);

/// Invoked to finish the task for test purpose.
void testingFinish() {
terminate(TaskState::kFinished).wait();
}

private:
Task(
const std::string& taskId,
Expand Down
10 changes: 6 additions & 4 deletions velox/exec/tests/SharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3794,7 +3794,7 @@ DEBUG_ONLY_TEST_F(SharedArbitrationTest, joinBuildSpillError) {
ASSERT_EQ(arbitrator_->stats().numReserves, numAddedPools_);
}

TEST_F(SharedArbitrationTest, DISABLED_concurrentArbitration) {
TEST_F(SharedArbitrationTest, concurrentArbitration) {
// Tries to replicate an actual workload by concurrently running multiple
// query shapes that support spilling (and hence can be forced to abort or
// spill by the arbitrator). Also adds an element of randomness by randomly
Expand Down Expand Up @@ -3831,13 +3831,13 @@ TEST_F(SharedArbitrationTest, DISABLED_concurrentArbitration) {
succinctBytes(totalCapacity),
succinctBytes(queryCapacity));
}
} testSettings[3] = {
} testSettings[] = {
{16 * MB, 128 * MB}, {128 * MB, 16 * MB}, {128 * MB, 128 * MB}};

for (const auto& testData : testSettings) {
SCOPED_TRACE(testData.debugString());
auto totalCapacity = testData.totalCapacity;
auto queryCapacity = testData.queryCapacity;
const auto totalCapacity = testData.totalCapacity;
const auto queryCapacity = testData.queryCapacity;
setupMemory(totalCapacity);

std::mutex mutex;
Expand Down Expand Up @@ -3895,6 +3895,8 @@ TEST_F(SharedArbitrationTest, DISABLED_concurrentArbitration) {
for (auto& queryThread : queryThreads) {
queryThread.join();
}
zombieTasks.clear();
waitForAllTasksToBeDeleted();
ASSERT_GT(arbitrator_->stats().numRequests, 0);
}
}
Expand Down
50 changes: 50 additions & 0 deletions velox/exec/tests/TaskTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "velox/exec/OutputBufferManager.h"
#include "velox/exec/PlanNodeStats.h"
#include "velox/exec/Values.h"
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/exec/tests/utils/Cursor.h"
#include "velox/exec/tests/utils/HiveConnectorTestBase.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
Expand Down Expand Up @@ -1414,4 +1415,53 @@ TEST_F(TaskTest, spillDirNotCreated) {
auto fs = filesystems::getFileSystem(tmpDirectoryPath, nullptr);
EXPECT_FALSE(fs->exists(tmpDirectoryPath));
}

DEBUG_ONLY_TEST_F(TaskTest, resumeAfterTaskFinish) {
auto probeVector = makeRowVector(
{"t_c0"}, {makeFlatVector<int32_t>(10, [](auto row) { return row; })});
auto buildVector = makeRowVector(
{"u_c0"}, {makeFlatVector<int32_t>(10, [](auto row) { return row; })});
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan =
PlanBuilder(planNodeIdGenerator)
.values({probeVector})
.hashJoin(
{"t_c0"},
{"u_c0"},
PlanBuilder(planNodeIdGenerator).values({buildVector}).planNode(),
"",
{"t_c0", "u_c0"})
.planFragment();

std::atomic<bool> valuesWaitFlag{true};
folly::EventCount valuesWait;
SCOPED_TESTVALUE_SET(
"facebook::velox::exec::Values::getOutput",
std::function<void(const velox::exec::Values*)>(
([&](const velox::exec::Values* values) {
valuesWait.await([&]() { return !valuesWaitFlag.load(); });
})));

auto task = Task::create(
"task",
std::move(plan),
0,
std::make_shared<core::QueryCtx>(driverExecutor_.get()));
task->start(4, 1);

// Request pause and then unblock operators to proceed.
auto pauseWait = task->requestPause();
valuesWaitFlag = false;
valuesWait.notifyAll();
// Wait for task pause to complete.
pauseWait.wait();
// Finish the task and for a hash join, the probe operator should still be in
// waiting for build stage.
task->testingFinish();
// Resume the task and expect all drivers to close.
Task::resume(task);
ASSERT_TRUE(waitForTaskCompletion(task.get()));
task.reset();
waitForAllTasksToBeDeleted();
}
} // namespace facebook::velox::exec::test
7 changes: 3 additions & 4 deletions velox/exec/tests/utils/Cursor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ bool waitForTaskDriversToFinish(exec::Task* task, uint64_t maxWaitMicros) {
}

if (task->numFinishedDrivers() != task->numTotalDrivers()) {
LOG(ERROR)
<< "Timed out waiting for all task drivers to finish. Finished drivers: "
<< task->numFinishedDrivers()
<< ". Total drivers: " << task->numTotalDrivers();
LOG(ERROR) << "Timed out waiting for all drivers of task " << task->taskId()
<< " to finish. Finished drivers: " << task->numFinishedDrivers()
<< ". Total drivers: " << task->numTotalDrivers();
}

return task->numFinishedDrivers() == task->numTotalDrivers();
Expand Down
2 changes: 1 addition & 1 deletion velox/exec/tests/utils/QueryAssertions.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class DuckDbQueryRunner {
std::pair<std::unique_ptr<TaskCursor>, std::vector<RowVectorPtr>> readCursor(
const CursorParameters& params,
std::function<void(exec::Task*)> addSplits,
uint64_t maxWaitMicros = 1'000'000);
uint64_t maxWaitMicros = 5'000'000);

/// The Task can return results before the Driver is finished executing.
/// Wait upto maxWaitMicros for the Task to finish as 'expectedState' before
Expand Down

0 comments on commit 4c5f917

Please sign in to comment.