diff --git a/Common/interface/AsyncInitializer.hpp b/Common/interface/AsyncInitializer.hpp index 70b2f66f0..cee124b30 100644 --- a/Common/interface/AsyncInitializer.hpp +++ b/Common/interface/AsyncInitializer.hpp @@ -106,7 +106,13 @@ class AsyncInitializer { VERIFY_EXPR(pThreadPool != nullptr); return std::unique_ptr{ - new AsyncInitializer{EnqueueAsyncWork(pThreadPool, ppPrerequisites, NumPrerequisites, std::forward(Handler))}, + new AsyncInitializer{ + EnqueueAsyncWork(pThreadPool, ppPrerequisites, NumPrerequisites, + [Handler = std::forward(Handler)](Uint32 ThreadId) mutable { + Handler(ThreadId); + return ASYNC_TASK_STATUS_COMPLETE; + }), + }, }; } diff --git a/Common/interface/ThreadPool.hpp b/Common/interface/ThreadPool.hpp index 0bba59379..ae69cd507 100644 --- a/Common/interface/ThreadPool.hpp +++ b/Common/interface/ThreadPool.hpp @@ -95,7 +95,8 @@ class AsyncTaskBase : public ObjectBase break; case ASYNC_TASK_STATUS_NOT_STARTED: - DEV_ERROR("NOT_STARTED is only allowed as initial task status."); + DEV_CHECK_ERR(m_TaskStatus == ASYNC_TASK_STATUS_RUNNING, + "A task should only be moved to NOT_STARTED state from RUNNING state."); break; case ASYNC_TASK_STATUS_RUNNING: @@ -184,8 +185,8 @@ RefCntAutoPtr EnqueueAsyncWork(IThreadPool* pThreadPool, virtual void DILIGENT_CALL_TYPE Run(Uint32 ThreadId) override final { - m_Handler(ThreadId); - SetStatus(ASYNC_TASK_STATUS_COMPLETE); + ASYNC_TASK_STATUS TaskStatus = m_Handler(ThreadId); + SetStatus(TaskStatus); } private: diff --git a/Tests/DiligentCoreTest/src/Common/ThreadPoolTest.cpp b/Tests/DiligentCoreTest/src/Common/ThreadPoolTest.cpp index 9dc711db3..b820cfab7 100644 --- a/Tests/DiligentCoreTest/src/Common/ThreadPoolTest.cpp +++ b/Tests/DiligentCoreTest/src/Common/ThreadPoolTest.cpp @@ -75,6 +75,8 @@ TEST(Common_ThreadPool, EnqueueTask) f = std::sin(f + 1.f); Results[i].store(f); WorkComplete[i].store(true); + + return ASYNC_TASK_STATUS_COMPLETE; }); } @@ -133,6 +135,8 @@ TEST(Common_ThreadPool, ProcessTask) f = std::sin(f + 1.f); Results[i].store(f); WorkComplete[i].store(true); + + return ASYNC_TASK_STATUS_COMPLETE; }); } @@ -326,6 +330,7 @@ TEST(Common_ThreadPool, Priorities) [&CompletionOrder, i](Uint32 ThreadId) // { CompletionOrder.push_back(i); + return ASYNC_TASK_STATUS_COMPLETE; }); } @@ -396,6 +401,8 @@ TEST(Common_ThreadPool, Prerequisites) } if (CorrectOrder) NumTasksCorrectlyOrdered.fetch_add(1); + + return ASYNC_TASK_STATUS_COMPLETE; }, static_cast(task) // Inverse priority so that the thread pool fixes it ); @@ -407,4 +414,32 @@ TEST(Common_ThreadPool, Prerequisites) } } + +TEST(Common_ThreadPool, ReRunTasks) +{ + auto pThreadPool = CreateThreadPool(ThreadPoolCreateInfo{4}); + ASSERT_NE(pThreadPool, nullptr); + + constexpr Uint32 NumTasks = 32; + std::vector> ReRunCounters(NumTasks); + + for (int i = 0; i < static_cast(ReRunCounters.size()); ++i) + ReRunCounters[i] = 32 + i; + + for (Uint32 task = 0; task < NumTasks; ++task) + { + EnqueueAsyncWork( + pThreadPool, + [task, &ReRunCounters](Uint32 ThreadId) // + { + int ReRunCounter = ReRunCounters[task].fetch_add(-1) - 1; + return ReRunCounter > 0 ? ASYNC_TASK_STATUS_NOT_STARTED : ASYNC_TASK_STATUS_COMPLETE; + }); + } + + pThreadPool->WaitForAllTasks(); + for (size_t i = 0; i < ReRunCounters.size(); ++i) + EXPECT_EQ(ReRunCounters[i], 0) << i; +} + } // namespace