From 08dedbfa517fd23e1ba492fa2b4630b38a6d177f Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Thu, 24 Oct 2024 14:26:05 +0100 Subject: [PATCH 1/7] Replaced type of cublas_handle map to CUcontext to remove UR dependency --- .../backends/cublas/cublas_scope_handle.cpp | 22 +++++-------------- .../backends/cublas/cublas_scope_handle.hpp | 18 +-------------- src/blas/backends/cublas/cublas_task.hpp | 12 ---------- 3 files changed, 7 insertions(+), 45 deletions(-) diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 8bb1145fa..54b8e396c 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -35,13 +35,8 @@ namespace cublas { * takes place if no other element in the container has a key equivalent to * the one being emplaced (keys in a map container are unique). */ -#ifdef ONEMKL_PI_INTERFACE_REMOVED -thread_local cublas_handle CublasScopedContextHandler::handle_helper = - cublas_handle{}; -#else -thread_local cublas_handle CublasScopedContextHandler::handle_helper = - cublas_handle{}; -#endif +thread_local cublas_handle CublasScopedContextHandler::handle_helper = + cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) : ih(ih), @@ -95,16 +90,11 @@ void ContextCallback(void* userData) { cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { auto cudaDevice = ih.get_native_device(); CUresult cuErr; - CUcontext desired; - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); -#ifdef ONEMKL_PI_INTERFACE_REMOVED - auto piPlacedContext_ = reinterpret_cast(desired); -#else - auto piPlacedContext_ = reinterpret_cast(desired); -#endif + CUcontext desiredCtx; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desiredCtx, cudaDevice); CUstream streamId = get_stream(queue); cublasStatus_t err; - auto it = 
handle_helper.cublas_handle_mapper_.find(piPlacedContext_); + auto it = handle_helper.cublas_handle_mapper_.find(desiredCtx); if (it != handle_helper.cublas_handle_mapper_.end()) { if (it->second == nullptr) { handle_helper.cublas_handle_mapper_.erase(it); @@ -131,7 +121,7 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); auto insert_iter = handle_helper.cublas_handle_mapper_.insert( - std::make_pair(piPlacedContext_, new std::atomic(handle))); + std::make_pair(desiredCtx, new std::atomic(handle))); sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, insert_iter.first->second); diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index d17909cfb..9b4d1c4c5 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -33,18 +33,6 @@ #include #endif -// After Plugin Interface removal in DPC++ ur.hpp is the new include -#if __has_include() -#include -#ifndef ONEMKL_PI_INTERFACE_REMOVED -#define ONEMKL_PI_INTERFACE_REMOVED -#endif -#elif __has_include() -#include -#else -#include -#endif - #include #include #include @@ -88,11 +76,7 @@ class CublasScopedContextHandler { sycl::context* placedContext_; bool needToRecover_; sycl::interop_handle& ih; -#ifdef ONEMKL_PI_INTERFACE_REMOVED - static thread_local cublas_handle handle_helper; -#else - static thread_local cublas_handle handle_helper; -#endif + static thread_local cublas_handle handle_helper; CUstream get_stream(const sycl::queue& queue); sycl::context get_context(const sycl::queue& queue); diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index 08d5cf70e..f4b530ddd 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -35,18 +35,6 @@ #else #include "cublas_scope_handle_hipsycl.hpp" -// 
After Plugin Interface removal in DPC++ ur.hpp is the new include -#if __has_include() -#include -#ifndef ONEMKL_PI_INTERFACE_REMOVED -#define ONEMKL_PI_INTERFACE_REMOVED -#endif -#elif __has_include() -#include -#else -#include -#endif - namespace sycl { using interop_handler = sycl::interop_handle; } From 8472bae01a53e316db62d9c535ad5c0e988fcd1e Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Thu, 24 Oct 2024 14:30:04 +0100 Subject: [PATCH 2/7] Removed checking if current Ctx is not Primary. It's always primary Ctx --- .../backends/cublas/cublas_scope_handle.cpp | 19 +------------------ .../backends/cublas/cublas_scope_handle.hpp | 1 - 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 54b8e396c..f50dbced8 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -39,32 +39,15 @@ thread_local cublas_handle CublasScopedContextHandler::handle_helper cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) - : ih(ih), - needToRecover_(false) { + : ih(ih) { placedContext_ = new sycl::context(queue.get_context()); auto cudaDevice = ih.get_native_device(); CUresult err; CUcontext desired; - CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); - if (original_ != desired) { - // Sets the desired context as the active one for the thread - CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); - // No context is installed and the suggested context is primary - // This is the most common case. We can activate the context in the - // thread and leave it there until all the PI context referring to the - // same underlying CUDA primary context are destroyed. This emulates - // the behaviour of the CUDA runtime api, and avoids costly context - // switches. 
No action is required on this side of the if. - needToRecover_ = !(original_ == nullptr); - } } CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) { - if (needToRecover_) { - CUresult err; - CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_); - } delete placedContext_; } diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index 9b4d1c4c5..31c1fa4c4 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -74,7 +74,6 @@ the handle must be destroyed when the context goes out of scope. This will bind class CublasScopedContextHandler { CUcontext original_; sycl::context* placedContext_; - bool needToRecover_; sycl::interop_handle& ih; static thread_local cublas_handle handle_helper; CUstream get_stream(const sycl::queue& queue); From 000ebf86ea31279d659fb7dc1a585ae312f6e8b9 Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Thu, 24 Oct 2024 14:39:34 +0100 Subject: [PATCH 3/7] Removed sycl::context* member and call to ContextSetExtendedDeleter. We don't need to do any early cleanup upon sycl context destruction. 
--- .../backends/cublas/cublas_scope_handle.cpp | 34 +------------------ .../backends/cublas/cublas_scope_handle.hpp | 2 -- 2 files changed, 1 insertion(+), 35 deletions(-) diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index f50dbced8..9913c8bce 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -39,36 +39,7 @@ thread_local cublas_handle CublasScopedContextHandler::handle_helper cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) - : ih(ih) { - placedContext_ = new sycl::context(queue.get_context()); - auto cudaDevice = ih.get_native_device(); - CUresult err; - CUcontext desired; - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); -} - -CublasScopedContextHandler::~CublasScopedContextHandler() noexcept(false) { - delete placedContext_; -} - -void ContextCallback(void* userData) { - auto* ptr = static_cast*>(userData); - if (!ptr) { - return; - } - auto handle = ptr->exchange(nullptr); - if (handle != nullptr) { - cublasStatus_t err1; - CUBLAS_ERROR_FUNC(cublasDestroy, err1, handle); - handle = nullptr; - } - else { - // if the handle is nullptr it means the handle was already destroyed by - // the cublas_handle destructor and we're free to delete the atomic - // object. 
- delete ptr; - } -} + : ih(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { auto cudaDevice = ih.get_native_device(); @@ -106,9 +77,6 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) auto insert_iter = handle_helper.cublas_handle_mapper_.insert( std::make_pair(desiredCtx, new std::atomic(handle))); - sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, - insert_iter.first->second); - return handle; } diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index 31c1fa4c4..da088958c 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -73,7 +73,6 @@ the handle must be destroyed when the context goes out of scope. This will bind class CublasScopedContextHandler { CUcontext original_; - sycl::context* placedContext_; sycl::interop_handle& ih; static thread_local cublas_handle handle_helper; CUstream get_stream(const sycl::queue& queue); @@ -82,7 +81,6 @@ class CublasScopedContextHandler { public: CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih); - ~CublasScopedContextHandler() noexcept(false); /** * @brief get_handle: creates the handle by implicitly impose the advice * given by nvidia for creating a cublas_handle. (e.g. 
one cuStream per device From 7b43b95df8ffaedc438e196eeb19805d7ea54432 Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Thu, 24 Oct 2024 15:10:48 +0100 Subject: [PATCH 4/7] Remove unnecessary includes --- src/blas/backends/cublas/cublas_scope_handle.cpp | 5 ----- src/blas/backends/cublas/cublas_scope_handle.hpp | 9 --------- 2 files changed, 14 deletions(-) diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 9913c8bce..537040634 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -17,11 +17,6 @@ * **************************************************************************/ #include "cublas_scope_handle.hpp" -#if __has_include() -#include -#else -#include -#endif namespace oneapi { namespace mkl { diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index da088958c..ab8ffff19 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -23,15 +23,6 @@ #else #include #endif -#if __has_include() -#if __SYCL_COMPILER_VERSION <= 20220930 -#include -#endif -#include -#else -#include -#include -#endif #include #include From 1961e14a9302233bf8c5a966ff75d9fdaa3dab3d Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Thu, 24 Oct 2024 21:52:33 +0100 Subject: [PATCH 5/7] Changed handle_helper to have unordered_map of CUdevice(s) -> cublasHandle_t --- .../backends/cublas/cublas_scope_handle.cpp | 40 ++++++------------- .../backends/cublas/cublas_scope_handle.hpp | 3 +- 2 files changed, 14 insertions(+), 29 deletions(-) diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 537040634..164c3c3aa 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -30,47 +30,33 @@ namespace cublas { * takes place if no 
other element in the container has a key equivalent to * the one being emplaced (keys in a map container are unique). */ -thread_local cublas_handle CublasScopedContextHandler::handle_helper = - cublas_handle{}; +thread_local cublas_handle CublasScopedContextHandler::handle_helper = + cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) : ih(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { - auto cudaDevice = ih.get_native_device(); - CUresult cuErr; - CUcontext desiredCtx; - CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desiredCtx, cudaDevice); + CUdevice device = ih.get_native_device(); CUstream streamId = get_stream(queue); cublasStatus_t err; - auto it = handle_helper.cublas_handle_mapper_.find(desiredCtx); - if (it != handle_helper.cublas_handle_mapper_.end()) { - if (it->second == nullptr) { - handle_helper.cublas_handle_mapper_.erase(it); - } - else { - auto handle = it->second->load(); - if (handle != nullptr) { - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; - } - else { - handle_helper.cublas_handle_mapper_.erase(it); - } - } + + if (handle_helper.cublas_handle_mapper_.count(device) > 0) { + cublasHandle_t handle = handle_helper.cublas_handle_mapper_[device]; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); + } + return handle; } cublasHandle_t handle; - CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); auto insert_iter = handle_helper.cublas_handle_mapper_.insert( - std::make_pair(desiredCtx, new std::atomic(handle))); + std::make_pair(device, handle)); return handle; } diff --git 
a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index ab8ffff19..803a98f32 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -63,9 +63,8 @@ the handle must be destroyed when the context goes out of scope. This will bind **/ class CublasScopedContextHandler { - CUcontext original_; sycl::interop_handle& ih; - static thread_local cublas_handle handle_helper; + static thread_local cublas_handle handle_helper; CUstream get_stream(const sycl::queue& queue); sycl::context get_context(const sycl::queue& queue); From 46a266157b8f01dd9222e62a2eba543fc6c25478 Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Fri, 8 Nov 2024 17:13:57 +0000 Subject: [PATCH 6/7] Rolled back to the version with unordered map but without contextSetExtendedDeleter It seems the static thread_local unordered map needs to stay because of all the thread shenanigans. But we're removing the use of detail namespace in sycl since it's not necessary for correctness. 
--- src/blas/backends/cublas/cublas_handle.hpp | 19 ++--------- .../backends/cublas/cublas_scope_handle.cpp | 32 +++++++++---------- .../backends/cublas/cublas_scope_handle.hpp | 4 +-- .../cublas/cublas_scope_handle_hipsycl.cpp | 28 ++++++---------- .../cublas/cublas_scope_handle_hipsycl.hpp | 1 - src/blas/backends/cublas/cublas_task.hpp | 2 +- 6 files changed, 29 insertions(+), 57 deletions(-) diff --git a/src/blas/backends/cublas/cublas_handle.hpp b/src/blas/backends/cublas/cublas_handle.hpp index 83a76c927..ce455925f 100644 --- a/src/blas/backends/cublas/cublas_handle.hpp +++ b/src/blas/backends/cublas/cublas_handle.hpp @@ -18,7 +18,6 @@ **************************************************************************/ #ifndef CUBLAS_HANDLE_HPP #define CUBLAS_HANDLE_HPP -#include #include namespace oneapi { @@ -28,26 +27,12 @@ namespace cublas { template struct cublas_handle { - using handle_container_t = std::unordered_map*>; + using handle_container_t = std::unordered_map; handle_container_t cublas_handle_mapper_{}; ~cublas_handle() noexcept(false) { for (auto& handle_pair : cublas_handle_mapper_) { cublasStatus_t err; - if (handle_pair.second != nullptr) { - auto handle = handle_pair.second->exchange(nullptr); - if (handle != nullptr) { - CUBLAS_ERROR_FUNC(cublasDestroy, err, handle); - handle = nullptr; - } - else { - // if the handle is nullptr it means the handle was already - // destroyed by the ContextCallback and we're free to delete the - // atomic object. 
- delete handle_pair.second; - } - - handle_pair.second = nullptr; - } + CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second); } cublas_handle_mapper_.clear(); } diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 164c3c3aa..142c36217 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -33,32 +33,32 @@ namespace cublas { thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; -CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) - : ih(ih) {} +CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { CUdevice device = ih.get_native_device(); CUstream streamId = get_stream(queue); cublasStatus_t err; - if (handle_helper.cublas_handle_mapper_.count(device) > 0) { - cublasHandle_t handle = handle_helper.cublas_handle_mapper_[device]; - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; + auto it = handle_helper.cublas_handle_mapper_.find(device); + if (it != handle_helper.cublas_handle_mapper_.end()) { + cublasHandle_t nativeHandle = it->second; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId); + } + return nativeHandle; } - cublasHandle_t handle; - CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); + cublasHandle_t nativeHandle; + CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle); + CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId); - auto insert_iter = 
handle_helper.cublas_handle_mapper_.insert( - std::make_pair(device, handle)); + auto insert_iter = + handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle)); - return handle; + return nativeHandle; } CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) { diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index 803a98f32..28ca1f71a 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -24,10 +24,8 @@ #include #endif -#include #include #include -#include #include "cublas_helper.hpp" #include "cublas_handle.hpp" @@ -69,7 +67,7 @@ class CublasScopedContextHandler { sycl::context get_context(const sycl::queue& queue); public: - CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih); + CublasScopedContextHandler(sycl::interop_handle& ih); /** * @brief get_handle: creates the handle by implicitly impose the advice diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp index 03c282aed..908600d27 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp @@ -36,31 +36,21 @@ cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) cublasStatus_t err; auto it = handle_helper.cublas_handle_mapper_.find(current_device); if (it != handle_helper.cublas_handle_mapper_.end()) { - if (it->second == nullptr) { - handle_helper.cublas_handle_mapper_.erase(it); - } - else { - auto handle = it->second->load(); - if (handle != nullptr) { - cudaStream_t currentStreamId; - CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); - if (currentStreamId != streamId) { - CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - } - return handle; - } - else { - handle_helper.cublas_handle_mapper_.erase(it); - } + cublasHandle_t 
handle = it->second; + cudaStream_t currentStreamId; + CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); + if (currentStreamId != streamId) { + CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); } + return handle; } cublasHandle_t handle; CUBLAS_ERROR_FUNC(cublasCreate, err, &handle); CUBLAS_ERROR_FUNC(cublasSetStream, err, handle, streamId); - auto insert_iter = handle_helper.cublas_handle_mapper_.insert( - std::make_pair(current_device, new std::atomic(handle))); + auto insert_iter = + handle_helper.cublas_handle_mapper_.insert(std::make_pair(current_device, handle)); return handle; } @@ -71,4 +61,4 @@ CUstream CublasScopedContextHandler::get_stream(const sycl::queue& queue) { } // namespace cublas } // namespace blas } // namespace mkl -} // namespace oneapi \ No newline at end of file +} // namespace oneapi diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp index 9e1eb89e5..7d218e355 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp @@ -25,7 +25,6 @@ #endif #include #include -#include #include "cublas_helper.hpp" #include "cublas_handle.hpp" namespace oneapi { diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index f4b530ddd..ae95e6eb1 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -60,7 +60,7 @@ static inline void host_task_internal(H& cgh, sycl::queue queue, F f) { #else cgh.host_task([f, queue](sycl::interop_handle ih) { #endif - auto sc = CublasScopedContextHandler(queue, ih); + auto sc = CublasScopedContextHandler(ih); f(sc); }); } From 59963204393a980f2de2cefd63f11afa8177f04e Mon Sep 17 00:00:00 2001 From: Konrad Kusiak Date: Fri, 8 Nov 2024 18:00:35 +0000 Subject: [PATCH 7/7] We need to set the context properly before destroying cublasHandles --- 
src/blas/backends/cublas/cublas_handle.hpp | 13 +++++++++++-- src/blas/backends/cublas/cublas_scope_handle.cpp | 3 +-- src/blas/backends/cublas/cublas_scope_handle.hpp | 2 +- .../backends/cublas/cublas_scope_handle_hipsycl.cpp | 4 ++-- .../backends/cublas/cublas_scope_handle_hipsycl.hpp | 2 +- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/blas/backends/cublas/cublas_handle.hpp b/src/blas/backends/cublas/cublas_handle.hpp index ce455925f..8b77282df 100644 --- a/src/blas/backends/cublas/cublas_handle.hpp +++ b/src/blas/backends/cublas/cublas_handle.hpp @@ -19,18 +19,27 @@ #ifndef CUBLAS_HANDLE_HPP #define CUBLAS_HANDLE_HPP #include +#include "cublas_helper.hpp" namespace oneapi { namespace mkl { namespace blas { namespace cublas { -template struct cublas_handle { - using handle_container_t = std::unordered_map; + using handle_container_t = std::unordered_map; handle_container_t cublas_handle_mapper_{}; ~cublas_handle() noexcept(false) { + CUresult err; + CUcontext original; + CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original); for (auto& handle_pair : cublas_handle_mapper_) { + CUcontext desired; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, handle_pair.first); + if (original != desired) { + // Sets the desired context as the active one for the thread in order to destroy its corresponding cublasHandle_t. + CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); + } cublasStatus_t err; CUBLAS_ERROR_FUNC(cublasDestroy, err, handle_pair.second); } diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index 142c36217..812d89d31 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -30,8 +30,7 @@ namespace cublas { * takes place if no other element in the container has a key equivalent to * the one being emplaced (keys in a map container are unique). 
*/ -thread_local cublas_handle CublasScopedContextHandler::handle_helper = - cublas_handle{}; +thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {} diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index 28ca1f71a..2f6027478 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -62,7 +62,7 @@ the handle must be destroyed when the context goes out of scope. This will bind class CublasScopedContextHandler { sycl::interop_handle& ih; - static thread_local cublas_handle handle_helper; + static thread_local cublas_handle handle_helper; CUstream get_stream(const sycl::queue& queue); sycl::context get_context(const sycl::queue& queue); diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp index 908600d27..8822151dd 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.cpp @@ -24,14 +24,14 @@ namespace mkl { namespace blas { namespace cublas { -thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; +thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{}; CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl::interop_handle& ih) : interop_h(ih) {} cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue& queue) { sycl::device device = queue.get_device(); - int current_device = interop_h.get_native_device(); + CUdevice current_device = interop_h.get_native_device(); CUstream streamId = get_stream(queue); cublasStatus_t err; auto it = handle_helper.cublas_handle_mapper_.find(current_device); diff --git a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp 
b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp index 7d218e355..84b28e0fd 100644 --- a/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle_hipsycl.hpp @@ -59,7 +59,7 @@ the handle must be destroyed when the context goes out of scope. This will bind class CublasScopedContextHandler { sycl::interop_handle interop_h; - static thread_local cublas_handle handle_helper; + static thread_local cublas_handle handle_helper; sycl::context get_context(const sycl::queue& queue); CUstream get_stream(const sycl::queue& queue);