diff --git a/velox/exec/tests/ExchangeClientTest.cpp b/velox/exec/tests/ExchangeClientTest.cpp
index 154439e699eb..97f88ee45e25 100644
--- a/velox/exec/tests/ExchangeClientTest.cpp
+++ b/velox/exec/tests/ExchangeClientTest.cpp
@@ -84,6 +84,11 @@ class ExchangeClientTest : public testing::Test,
       bool atEnd;
       ContinueFuture future;
       auto pages = client.next(1, &atEnd, &future);
+      if (pages.empty()) {
+        auto& exec = folly::QueuedImmediateExecutor::instance();
+        std::move(future).via(&exec).wait();
+        pages = client.next(1, &atEnd, &future);
+      }
       ASSERT_EQ(1, pages.size());
     }
   }
@@ -344,5 +349,55 @@ TEST_F(ExchangeClientTest, sourceTimeout) {
   ASSERT_TRUE(atEnd);
 }
 
+// Verifies that a slow "LocalExchangeSource" test-value hook does not cause a
+// source timeout: the hook sleeps for 2 * kDefaultMaxWaitSeconds, yet the page
+// must still be fetched and the timeout hook must never run.
+// NOTE(review): the "LocalExchangeSource" hook is presumably invoked from the
+// source's result callback - confirm in LocalExchangeSource.cpp.
+TEST_F(ExchangeClientTest, timeoutDuringValueCallback) {
+  common::testutil::TestValue::enable();
+  auto row = makeRowVector({makeFlatVector({1, 2, 3})});
+
+  auto plan = test::PlanBuilder()
+                  .values({row})
+                  .partitionedOutput({"c0"}, 100)
+                  .planNode();
+  auto taskId = "local://t1";
+  auto task = makeTask(taskId, plan);
+
+  bufferManager_->initializeTask(
+      task, core::PartitionedOutputNode::Kind::kPartitioned, 100, 16);
+
+  ExchangeClient client(
+      "t", 17, pool(), ExchangeClient::kDefaultMaxQueuedBytes);
+  client.addRemoteTaskId(taskId);
+  // Counts calls to the "...LocalExchangeSource::timeout" test-value hook.
+  int32_t numTimeouts = 0;
+  SCOPED_TESTVALUE_SET(
+      "facebook::velox::exec::test::LocalExchangeSource::timeout",
+      std::function(([&](void* /*ignore*/) { ++numTimeouts; })));
+
+  // Stall the "LocalExchangeSource" hook for twice the default max wait.
+  SCOPED_TESTVALUE_SET(
+      "facebook::velox::exec::test::LocalExchangeSource",
+      std::function(([&](void* /*pages*/) {
+        std::this_thread::sleep_for(
+            std::chrono::seconds(2 * ExchangeClient::kDefaultMaxWaitSeconds));
+      })));
+
+  // Enqueue the data from a separate thread after a 100 ms delay.
+  auto thread = std::thread([&]() {
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    enqueue(taskId, 17, row);
+  });
+
+  fetchPages(client, 1);
+  thread.join();
+  EXPECT_EQ(0, numTimeouts);
+
+  task->requestCancel();
+  bufferManager_->removeTask(taskId);
+}
+
 } // namespace
 } // namespace facebook::velox::exec
diff --git
a/velox/exec/tests/utils/LocalExchangeSource.cpp b/velox/exec/tests/utils/LocalExchangeSource.cpp index 3086e9152aca..33a29c8be08a 100644 --- a/velox/exec/tests/utils/LocalExchangeSource.cpp +++ b/velox/exec/tests/utils/LocalExchangeSource.cpp @@ -56,11 +56,24 @@ class LocalExchangeSource : public exec::ExchangeSource { VELOX_CHECK(requestPending_); auto requestedSequence = sequence_; auto self = shared_from_this(); + + // Have a flag shared between the data available and timeout callbacks. Only + // one of these must run but they could overlap at call time. + static std::mutex realizeMutex; + auto state = std::make_shared(State::kPending); + // Since this lambda may outlive 'this', we need to capture a // shared_ptr to the current object (self). - auto resultCallback = [self, requestedSequence, buffers, this]( + auto resultCallback = [self, requestedSequence, buffers, state, this]( std::vector> data, int64_t sequence) { + { + std::lock_guard l(realizeMutex); + if (*state != State::kPending) { + return; + } + *state = State::kResultReceived; + } if (requestedSequence > sequence) { VLOG(2) << "Receives earlier sequence than requested: task " << taskId_ << ", destination " << destination_ << ", requested " @@ -127,24 +140,55 @@ class LocalExchangeSource : public exec::ExchangeSource { if (!requestPromise.isFulfilled()) { requestPromise.setValue(Response{totalBytes, atEnd_}); } + { + std::lock_guard l(realizeMutex); + *state = State::kResultProcessed; + } }; // Call the callback in any case after timeout. 
auto& exec = folly::QueuedImmediateExecutor::instance();
+
     future = std::move(future).via(&exec).onTimeout(
-        std::chrono::seconds(maxWaitSeconds), [self, this] {
-          common::testutil::TestValue::adjust(
-              "facebook::velox::exec::test::LocalExchangeSource::timeout",
-              this);
-          VeloxPromise requestPromise;
-          {
-            std::lock_guard l(queue_->mutex());
-            requestPending_ = false;
-            requestPromise = std::move(promise_);
-          }
+        std::chrono::seconds(maxWaitSeconds), [self, state, this] {
+          // The timeout callback must not realize promises while a result is
+          // being processed. If the result callback has already claimed the
+          // request (kResultReceived), poll until it has finished processing
+          // (kResultProcessed). Returning a value after that should be a no-op
+          // since the promise already has a value.
+          bool done = false;
+          bool timeout = false;
+          do {
+            {
+              std::lock_guard l(realizeMutex);
+              if (*state == State::kPending) {
+                *state = State::kTimeout;
+                timeout = true;
+                done = true;
+              } else if (*state == State::kResultProcessed) {
+                done = true;
+              }
+            }
+            if (!done) {
+              // Wait for the result callback to finish on another thread. Must
+              // not set the future until the other thread is finished.
+              std::this_thread::sleep_for(std::chrono::milliseconds(1));
+            }
+          } while (!done);
           Response response = {0, false};
-          if (!requestPromise.isFulfilled()) {
-            requestPromise.setValue(response);
+          if (timeout) {
+            common::testutil::TestValue::adjust(
+                "facebook::velox::exec::test::LocalExchangeSource::timeout",
+                this);
+            VeloxPromise requestPromise;
+            {
+              std::lock_guard l(queue_->mutex());
+              requestPending_ = false;
+              requestPromise = std::move(promise_);
+            }
+            if (!requestPromise.isFulfilled()) {
+              requestPromise.setValue(response);
+            }
           }
           return response;
         });
@@ -173,6 +217,17 @@ class LocalExchangeSource : public exec::ExchangeSource {
   }

  private:
+  // Progress of one request, shared by the result and timeout callbacks so
+  // that only one of them completes it:
+  //   kPending         - neither callback has claimed the request yet.
+  //   kResultReceived  - the result callback is processing the data.
+  //   kResultProcessed - the result callback has finished.
+  //   kTimeout         - the timeout callback claimed the request.
+  // If the timeout fires while state is kResultReceived, it must wait until
+  // state is kResultProcessed. If a result arrives when state != kPending,
+  // the result is ignored.
+  enum class State { kPending, kResultReceived, kResultProcessed, kTimeout };
+
   bool checkSetRequestPromise() {
     VeloxPromise promise;
     {