diff --git a/include/dlaf/blas/tile.h b/include/dlaf/blas/tile.h index d534f93d5d..1f1f8ee58c 100644 --- a/include/dlaf/blas/tile.h +++ b/include/dlaf/blas/tile.h @@ -28,130 +28,9 @@ #include #ifdef DLAF_WITH_GPU -#include - #include -#include +#include #include - -#ifdef DLAF_WITH_HIP - -#define DLAF_GET_ROCBLAS_WORKSPACE(f) \ - [&]() { \ - std::size_t workspace_size; \ - DLAF_GPUBLAS_CHECK_ERROR( \ - rocblas_start_device_memory_size_query(static_cast(handle))); \ - DLAF_ROCBLAS_WORKSPACE_CHECK_ERROR(rocblas_##f(handle, std::forward(args)...)); \ - DLAF_GPUBLAS_CHECK_ERROR(rocblas_stop_device_memory_size_query(static_cast(handle), \ - &workspace_size)); \ - return ::dlaf::memory::MemoryView(to_int(workspace_size)); \ - }(); - -namespace dlaf::tile::internal { -inline void extendROCBlasWorkspace(cublasHandle_t handle, - ::dlaf::memory::MemoryView&& workspace) { - whip::stream_t stream; - DLAF_GPUBLAS_CHECK_ERROR(cublasGetStream(handle, &stream)); - auto f = [workspace = std::move(workspace)](whip::error_t status) { whip::check_error(status); }; - pika::cuda::experimental::detail::add_event_callback(std::move(f), stream); -} -} - -#define DLAF_DEFINE_GPUBLAS_OP(Name, Type, f) \ - template <> \ - struct Name { \ - template \ - static void call(cublasHandle_t handle, Args&&... args) { \ - auto workspace = DLAF_GET_ROCBLAS_WORKSPACE(f); \ - DLAF_GPUBLAS_CHECK_ERROR(rocblas_set_workspace(static_cast(handle), workspace(), \ - to_sizet(workspace.size()))); \ - DLAF_GPUBLAS_CHECK_ERROR(rocblas_##f(handle, std::forward(args)...)); \ - DLAF_GPUBLAS_CHECK_ERROR(rocblas_set_workspace(static_cast(handle), nullptr, 0)); \ - ::dlaf::tile::internal::extendROCBlasWorkspace(handle, std::move(workspace)); \ - } \ - } - -#elif defined(DLAF_WITH_CUDA) - -#define DLAF_DEFINE_GPUBLAS_OP(Name, Type, f) \ - template <> \ - struct Name { \ - template \ - static void call(Args&&... args) { \ - DLAF_GPUBLAS_CHECK_ERROR(cublas##f##_v2(std::forward(args)...)); \ - } \ - } - -#endif - -#define DLAF_DECLARE_GPUBLAS_OP(Name) \ - template \ - struct Name - -#ifdef DLAF_WITH_HIP -#define DLAF_MAKE_GPUBLAS_OP(Name, f) \ - DLAF_DECLARE_GPUBLAS_OP(Name); \ - DLAF_DEFINE_GPUBLAS_OP(Name, float, s##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, double, d##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, c##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, z##f) - -#define DLAF_MAKE_GPUBLAS_SYHE_OP(Name, f) \ - DLAF_DECLARE_GPUBLAS_OP(Name); \ - DLAF_DEFINE_GPUBLAS_OP(Name, float, ssy##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, double, dsy##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, che##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, zhe##f) - -#elif defined(DLAF_WITH_CUDA) -#define DLAF_MAKE_GPUBLAS_OP(Name, f) \ - DLAF_DECLARE_GPUBLAS_OP(Name); \ - DLAF_DEFINE_GPUBLAS_OP(Name, float, S##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, double, D##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, C##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, Z##f) - -#define DLAF_MAKE_GPUBLAS_SYHE_OP(Name, f) \ - DLAF_DECLARE_GPUBLAS_OP(Name); \ - DLAF_DEFINE_GPUBLAS_OP(Name, float, Ssy##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, double, Dsy##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, Che##f); \ - DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, Zhe##f) -#endif - -namespace dlaf::gpublas::internal { - -// Level 1 -DLAF_MAKE_GPUBLAS_OP(Axpy, axpy); - -// Level 2 -DLAF_MAKE_GPUBLAS_OP(Gemv, gemv); - -DLAF_MAKE_GPUBLAS_OP(Trmv, trmv); - -// Level 3 -DLAF_MAKE_GPUBLAS_OP(Gemm, gemm); - -DLAF_MAKE_GPUBLAS_SYHE_OP(Hemm, mm); - -DLAF_MAKE_GPUBLAS_SYHE_OP(Her2k, r2k); - -DLAF_MAKE_GPUBLAS_SYHE_OP(Herk, rk); - -#if defined(DLAF_WITH_CUDA) -DLAF_MAKE_GPUBLAS_OP(Trmm, trmm); -#elif defined(DLAF_WITH_HIP) - -#if ROCBLAS_VERSION_MAJOR >= 3 && defined(ROCBLAS_V3) -DLAF_MAKE_GPUBLAS_OP(Trmm, trmm); -#else -DLAF_MAKE_GPUBLAS_OP(Trmm, trmm_outofplace); -#endif - -#endif - -DLAF_MAKE_GPUBLAS_OP(Trsm, trsm); -} #endif namespace dlaf { diff --git a/include/dlaf/gpu/blas/gpublas.h b/include/dlaf/gpu/blas/gpublas.h new file mode 100644 index 0000000000..8fb3e07832 --- /dev/null +++ b/include/dlaf/gpu/blas/gpublas.h @@ -0,0 +1,141 @@ +// +// Distributed Linear Algebra with Future (DLAF) +// +// Copyright (c) 2018-2024, ETH Zurich +// All rights reserved. +// +// Please, refer to the LICENSE file in the root directory. +// SPDX-License-Identifier: BSD-3-Clause +// +#pragma once + +/// @file +/// Provides gpublas wrappers for BLAS operations. + +#include +#include + +#include + +#include +#include +#include + +#ifdef DLAF_WITH_HIP + +#define DLAF_GET_ROCBLAS_WORKSPACE(f) \ + [&]() { \ + std::size_t workspace_size; \ + DLAF_GPUBLAS_CHECK_ERROR( \ + rocblas_start_device_memory_size_query(static_cast(handle))); \ + DLAF_ROCBLAS_WORKSPACE_CHECK_ERROR(rocblas_##f(handle, std::forward(args)...)); \ + DLAF_GPUBLAS_CHECK_ERROR(rocblas_stop_device_memory_size_query(static_cast(handle), \ + &workspace_size)); \ + return ::dlaf::memory::MemoryView(to_int(workspace_size)); \ + }(); + +namespace dlaf::tile::internal { +inline void extendROCBlasWorkspace(cublasHandle_t handle, + ::dlaf::memory::MemoryView&& workspace) { + whip::stream_t stream; + DLAF_GPUBLAS_CHECK_ERROR(cublasGetStream(handle, &stream)); + auto f = [workspace = std::move(workspace)](whip::error_t status) { whip::check_error(status); }; + pika::cuda::experimental::detail::add_event_callback(std::move(f), stream); +} +} + +#define DLAF_DEFINE_GPUBLAS_OP(Name, Type, f) \ + template <> \ + struct Name { \ + template \ + static void call(cublasHandle_t handle, Args&&... args) { \ + auto workspace = DLAF_GET_ROCBLAS_WORKSPACE(f); \ + DLAF_GPUBLAS_CHECK_ERROR(rocblas_set_workspace(static_cast(handle), workspace(), \ + to_sizet(workspace.size()))); \ + DLAF_GPUBLAS_CHECK_ERROR(rocblas_##f(handle, std::forward(args)...)); \ + DLAF_GPUBLAS_CHECK_ERROR(rocblas_set_workspace(static_cast(handle), nullptr, 0)); \ + ::dlaf::tile::internal::extendROCBlasWorkspace(handle, std::move(workspace)); \ + } \ + } + +#elif defined(DLAF_WITH_CUDA) + +#define DLAF_DEFINE_GPUBLAS_OP(Name, Type, f) \ + template <> \ + struct Name { \ + template \ + static void call(Args&&... args) { \ + DLAF_GPUBLAS_CHECK_ERROR(cublas##f##_v2(std::forward(args)...)); \ + } \ + } + +#endif + +#define DLAF_DECLARE_GPUBLAS_OP(Name) \ + template \ + struct Name + +#ifdef DLAF_WITH_HIP +#define DLAF_MAKE_GPUBLAS_OP(Name, f) \ + DLAF_DECLARE_GPUBLAS_OP(Name); \ + DLAF_DEFINE_GPUBLAS_OP(Name, float, s##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, double, d##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, c##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, z##f) + +#define DLAF_MAKE_GPUBLAS_SYHE_OP(Name, f) \ + DLAF_DECLARE_GPUBLAS_OP(Name); \ + DLAF_DEFINE_GPUBLAS_OP(Name, float, ssy##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, double, dsy##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, che##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, zhe##f) + +#elif defined(DLAF_WITH_CUDA) +#define DLAF_MAKE_GPUBLAS_OP(Name, f) \ + DLAF_DECLARE_GPUBLAS_OP(Name); \ + DLAF_DEFINE_GPUBLAS_OP(Name, float, S##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, double, D##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, C##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, Z##f) + +#define DLAF_MAKE_GPUBLAS_SYHE_OP(Name, f) \ + DLAF_DECLARE_GPUBLAS_OP(Name); \ + DLAF_DEFINE_GPUBLAS_OP(Name, float, Ssy##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, double, Dsy##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, Che##f); \ + DLAF_DEFINE_GPUBLAS_OP(Name, std::complex, Zhe##f) +#endif + +namespace dlaf::gpublas::internal { + +// Level 1 +DLAF_MAKE_GPUBLAS_OP(Axpy, axpy); + +// Level 2 +DLAF_MAKE_GPUBLAS_OP(Gemv, gemv); + +DLAF_MAKE_GPUBLAS_OP(Trmv, trmv); + +// Level 3 +DLAF_MAKE_GPUBLAS_OP(Gemm, gemm); + +DLAF_MAKE_GPUBLAS_SYHE_OP(Hemm, mm); + +DLAF_MAKE_GPUBLAS_SYHE_OP(Her2k, r2k); + +DLAF_MAKE_GPUBLAS_SYHE_OP(Herk, rk); + +#if defined(DLAF_WITH_CUDA) +DLAF_MAKE_GPUBLAS_OP(Trmm, trmm); +#elif defined(DLAF_WITH_HIP) + +#if ROCBLAS_VERSION_MAJOR >= 3 && defined(ROCBLAS_V3) +DLAF_MAKE_GPUBLAS_OP(Trmm, trmm); +#else +DLAF_MAKE_GPUBLAS_OP(Trmm, trmm_outofplace); +#endif + +#endif + +DLAF_MAKE_GPUBLAS_OP(Trsm, trsm); +} diff --git a/miniapp/kernel/miniapp_larft_gemv.cpp b/miniapp/kernel/miniapp_larft_gemv.cpp index 9bde1a5470..083acfecb9 100644 --- a/miniapp/kernel/miniapp_larft_gemv.cpp +++ b/miniapp/kernel/miniapp_larft_gemv.cpp @@ -9,6 +9,7 @@ // #include +#include #include #include diff --git a/src/lapack/gpu/larft.cu b/src/lapack/gpu/larft.cu index 297bf42f8d..fc0c887e5e 100644 --- a/src/lapack/gpu/larft.cu +++ b/src/lapack/gpu/larft.cu @@ -11,13 +11,10 @@ #include #include -#include #include #include -#include -#include +#include #include -#include #include #include #include diff --git a/src/lapack/gpu/laset.cu b/src/lapack/gpu/laset.cu index dea96572d9..c0e68ca1d3 100644 --- a/src/lapack/gpu/laset.cu +++ b/src/lapack/gpu/laset.cu @@ -12,6 +12,7 @@ #include #include +#include #include #include