diff --git a/sycl/source/detail/global_handler.cpp b/sycl/source/detail/global_handler.cpp index 51a77ab85a4d..4d1040f21bdf 100644 --- a/sycl/source/detail/global_handler.cpp +++ b/sycl/source/detail/global_handler.cpp @@ -38,7 +38,7 @@ using LockGuard = std::lock_guard; SpinLock GlobalHandler::MSyclGlobalHandlerProtector{}; // forward decl -void shutdown_early(); +void shutdown_early(bool); void shutdown_late(); #ifdef _WIN32 BOOL isLinkedStatically(); @@ -283,12 +283,12 @@ struct StaticVarShutdownHandler { // If statically linked, DllMain will not be called. So we do its work // here. if (isLinkedStatically()) { - shutdown_early(); + shutdown_early(true); } shutdown_late(); #else - shutdown_early(); + shutdown_early(true); #endif } catch (std::exception &e) { __SYCL_REPORT_EXCEPTION_TO_STREAM( @@ -343,7 +343,10 @@ void GlobalHandler::drainThreadPool() { MHostTaskThreadPool.Inst->drain(); } -void shutdown_early() { +// Note: this function can be called on Windows twice: +// 1) when library is unloaded via FreeLibrary +// 2) when process is being terminated +void shutdown_early(bool CanJoinThreads = true) { const LockGuard Lock{GlobalHandler::MSyclGlobalHandlerProtector}; GlobalHandler *&Handler = GlobalHandler::getInstancePtr(); if (!Handler) @@ -362,8 +365,10 @@ void shutdown_early() { // upon its release Handler->prepareSchedulerToRelease(true); - if (Handler->MHostTaskThreadPool.Inst) - Handler->MHostTaskThreadPool.Inst->finishAndWait(); + if (Handler->MHostTaskThreadPool.Inst) { + Handler->MHostTaskThreadPool.Inst->finishAndWait(CanJoinThreads); + Handler->MHostTaskThreadPool.Inst.reset(nullptr); + } // This releases OUR reference to the default context, but // other may yet have refs @@ -424,7 +429,14 @@ extern "C" __SYCL_EXPORT BOOL WINAPI DllMain(HINSTANCE hinstDLL, std::cout << "---> DLL_PROCESS_DETACH syclx.dll\n" << std::endl; try { - shutdown_early(); + // WA for threads handling. 
We must call join() or detach() on host task + // execution thread to avoid UB. lpReserved == NULL if library is unloaded + // via FreeLibrary. In this case we can't join threads within DllMain call + // due to global loader lock and DLL_THREAD_DETACH signalling. lpReserved + // != NULL if library is unloaded during process termination. In this case + // Windows terminates threads but leaves them in signalled state and + // suppresses DLL_THREAD_DETACH notifications, so join() is a NOP. + shutdown_early(lpReserved != NULL); } catch (std::exception &e) { __SYCL_REPORT_EXCEPTION_TO_STREAM("exception in DLL_PROCESS_DETACH", e); return FALSE; diff --git a/sycl/source/detail/global_handler.hpp b/sycl/source/detail/global_handler.hpp index fd3dd7e1fc0a..d0a9ae0fa4e6 100644 --- a/sycl/source/detail/global_handler.hpp +++ b/sycl/source/detail/global_handler.hpp @@ -97,7 +97,7 @@ class GlobalHandler { bool OkToDefer = true; - friend void shutdown_early(); + friend void shutdown_early(bool); friend void shutdown_late(); friend class ObjectUsageCounter; static GlobalHandler *&getInstancePtr(); diff --git a/sycl/source/detail/thread_pool.hpp b/sycl/source/detail/thread_pool.hpp index e9d441d6d27d..d24e48aa6bc6 100644 --- a/sycl/source/detail/thread_pool.hpp +++ b/sycl/source/detail/thread_pool.hpp @@ -32,6 +32,30 @@ class ThreadPool { bool MStop = false; std::atomic_uint MJobsInPool; +#ifdef _WIN32 + class ThreadExitTracker { + public: + void wait(size_t ThreadCount) { + std::unique_lock lk(MWorkerExitMutex); + MWorkerExitCV.wait( + lk, [&ThreadCount, this] { return MWorkerExitCount == ThreadCount; }); + } + + void signalAboutExit() { + { + std::lock_guard lk(MWorkerExitMutex); + MWorkerExitCount++; + } + MWorkerExitCV.notify_one(); + } + + private: + std::mutex MWorkerExitMutex; + std::condition_variable MWorkerExitCV; + size_t MWorkerExitCount{}; + } WinThreadExitTracker; +#endif + void worker() { GlobalHandler::instance().registerSchedulerUsage(/*ModifyCounter*/ false); 
std::unique_lock Lock(MJobQueueMutex); @@ -39,8 +63,12 @@ class ThreadPool { MDoSmthOrStop.wait(Lock, [this]() { return !MJobQueue.empty() || MStop; }); - if (MStop) - break; + if (MStop) { +#ifdef _WIN32 + WinThreadExitTracker.signalAboutExit(); +#endif + return; + } std::function Job = std::move(MJobQueue.front()); MJobQueue.pop(); @@ -76,14 +104,14 @@ class ThreadPool { ~ThreadPool() { try { #ifndef _WIN32 - finishAndWait(); + finishAndWait(true); #endif } catch (std::exception &e) { __SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ThreadPool", e); } } - void finishAndWait() { + void finishAndWait(bool CanJoinThreads) { { std::lock_guard Lock(MJobQueueMutex); MStop = true; @@ -91,6 +119,18 @@ class ThreadPool { MDoSmthOrStop.notify_all(); +#ifdef _WIN32 + if (!CanJoinThreads) { + WinThreadExitTracker.wait(MThreadCount); + for (std::thread &Thread : MLaunchedThreads) + Thread.detach(); + return; + } +#else + // We always can join on Linux. + std::ignore = CanJoinThreads; +#endif + for (std::thread &Thread : MLaunchedThreads) if (Thread.joinable()) Thread.join();