Skip to content

Commit

Permalink
Update trmm to use trmm_outofplace with rocblas
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Sep 8, 2023
1 parent 52704e9 commit 1e46a36
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions include/dlaf/blas/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,31 @@ DLAF_MAKE_GPUBLAS_SYHE_OP(Her2k, r2k);

DLAF_MAKE_GPUBLAS_SYHE_OP(Herk, rk);

#if defined(DLAF_WITH_HIP)
#if defined(DLAF_WITH_CUDA)
DLAF_MAKE_GPUBLAS_OP(Trmm, trmm);
#elif defined(DLAF_WITH_HIP)

#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
#endif

// TODO: What will be the upper bound?
#if HIP_VERSION < 50000000 || 50700000 <= HIP_VERSION
DLAF_MAKE_GPUBLAS_OP(Trmm, trmm);
#if defined(DLAF_WITH_HIP)
#else
DLAF_MAKE_GPUBLAS_OP(Trmm, trmm_outofplace);
#endif

#if defined(__clang__)
#pragma clang diagnostic pop
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

#endif

DLAF_MAKE_GPUBLAS_OP(Trsm, trsm);
Expand Down Expand Up @@ -406,12 +415,7 @@ void her2k(cublasHandle_t handle, const blas::Uplo uplo, blas::Op op, const T al
using util::blasToCublas;
using util::blasToCublasCast;
auto s = getHer2kSizes(op, a, b, c);
#ifdef DLAF_WITH_HIP
// Note:
// Up to date the fix for this problem is on rocblas@develop, which should be included in
// the next 5.2.0 release.
//
// https://github.com/ROCmSoftwarePlatform/rocBLAS/commit/e714f1f29ab71dfcdfa4add4462548b34d1cd9e8
#if defined(DLAF_WITH_HIP) && HIP_VERSION < 50200000
if (!isComplex_v<T> && op == blas::Op::ConjTrans)
op = blas::Op::Trans;
#endif
Expand Down Expand Up @@ -445,7 +449,7 @@ void trmm(cublasHandle_t handle, const blas::Side side, const blas::Uplo uplo, c
gpublas::internal::Trmm<T>::call(handle, blasToCublas(side), blasToCublas(uplo), blasToCublas(op),
blasToCublas(diag), to_int(s.m), to_int(s.n),
blasToCublasCast(&alpha), blasToCublasCast(a.ptr()), to_int(a.ld()),
#ifdef DLAF_WITH_CUDA
#if defined(DLAF_WITH_CUDA) || (defined(DLAF_WITH_HIP) && HIP_VERSION >= 50000000)
blasToCublasCast(b.ptr()), to_int(b.ld()),
#endif
blasToCublasCast(b.ptr()), to_int(b.ld()));
Expand All @@ -460,15 +464,16 @@ void trmm3(cublasHandle_t handle, const blas::Side side, const blas::Uplo uplo,
auto s = tile::internal::getTrmm3Sizes(side, a, b, c);
DLAF_ASSERT(b.ptr() == nullptr || b.ptr() != c.ptr(), b.ptr(), c.ptr());

#ifdef DLAF_WITH_HIP
#if defined(DLAF_WITH_HIP) && HIP_VERSION < 50000000
whip::stream_t stream;
DLAF_GPUBLAS_CHECK_ERROR(cublasGetStream(handle, &stream));
matrix::internal::copy(b, c, stream);
#endif

gpublas::internal::Trmm<T>::call(handle, blasToCublas(side), blasToCublas(uplo), blasToCublas(op),
blasToCublas(diag), to_int(s.m), to_int(s.n),
blasToCublasCast(&alpha), blasToCublasCast(a.ptr()), to_int(a.ld()),
#ifdef DLAF_WITH_CUDA
#if defined(DLAF_WITH_CUDA) || (defined(DLAF_WITH_HIP) && HIP_VERSION >= 50000000)
blasToCublasCast(b.ptr()), to_int(b.ld()),
#endif
blasToCublasCast(c.ptr()), to_int(c.ld()));
Expand Down

0 comments on commit 1e46a36

Please sign in to comment.