Skip to content

Commit

Permalink
Patch cuBLAS implementations (elemental#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
benson31 authored Sep 29, 2020
1 parent c79787c commit 09128aa
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 11 deletions.
10 changes: 10 additions & 0 deletions include/hydrogen/blas/BLAS_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ namespace gpu_blas
{
/** @brief Set the pointer mode of the underlying library. */
void SetPointerMode(PointerMode mode);

/** @brief Request that the underlying library use specialized tensor
* instructions.
*
* This is not a guarantee that such operations are available or will
* be used. However, if the library/hardware does expose such
* features, this will suggest to the library that they be used
* whenever possible.
*/
void RequestTensorOperations();
}
}// namespace hydrogen
#endif // HYDROGEN_BLAS_COMMON_HPP_
3 changes: 1 addition & 2 deletions include/hydrogen/blas/GPU_BLAS_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ void Dot(SizeT num_entries,
* @param num_entries The number of entries in X.
* @param X The vector (device memory).
* @param stride_X The stride of X.
* @param result The result of the dot product (host or device memory).
* @param result The result of the norm (host or device memory).
* @param[in] syncinfo The synchronization information for this
* operation.
*
Expand Down Expand Up @@ -538,7 +538,6 @@ void GemmStridedBatched(
SizeT batchCount,
SyncInfo<Device::GPU> const& syncinfo);


///@}
/** @name BLAS-like Extension Routines */
///@{
Expand Down
154 changes: 154 additions & 0 deletions include/hydrogen/blas/GPU_BLAS_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,44 @@ void Copy2DImpl(SizeT nrows, SizeT ncols,
reinterpret_cast<NTP>(B), ldb);
}

template <typename T, typename SizeT,
typename=EnableWhen<IsSupportedType<T,BLAS_Op::DOT>>>
void DotImpl(SizeT num_entries,
T const* X, SizeT stride_X,
T const* Y, SizeT stride_Y,
T* result,
SyncInfo<Device::GPU> const& si)
{
using NTP = MakePointer<NativeType<T>>;
using CNTP = MakePointerToConst<NativeType<T>>;

SyncManager mgr(GetLibraryHandle(), si);
gpu_blas_impl::Dot(
GetLibraryHandle(),
num_entries,
reinterpret_cast<CNTP>(X), stride_X,
reinterpret_cast<CNTP>(Y), stride_Y,
reinterpret_cast<NTP>(result));
}

template <typename T, typename SizeT,
typename=EnableWhen<IsSupportedType<T,BLAS_Op::DOT>>>
void Nrm2Impl(SizeT num_entries,
T const* X, SizeT stride_X,
T* result,
SyncInfo<Device::GPU> const& si)
{
using NTP = MakePointer<NativeType<T>>;
using CNTP = MakePointerToConst<NativeType<T>>;

SyncManager mgr(GetLibraryHandle(), si);
gpu_blas_impl::Nrm2(
GetLibraryHandle(),
num_entries,
reinterpret_cast<CNTP>(X), stride_X,
reinterpret_cast<NTP>(result));
}

template <typename T, typename SizeT,
typename=EnableWhen<IsSupportedType<T,BLAS_Op::SCAL>>>
void ScaleImpl(SizeT num_entries,
Expand Down Expand Up @@ -384,6 +422,36 @@ void GemmImpl(
reinterpret_cast<NTP>(C), ToSizeT(ldc));
}

/** @brief Forward a strided-batched GEMM to the backend BLAS library.
 *
 *  Enabled only when T is a supported type for GEMMSTRIDEDBATCHED.
 *  The library handle is synchronized with @p si for the duration of
 *  the call.
 *
 *  NOTE(review): the stride arguments are funneled through ToSizeT
 *  just like the leading dimensions — confirm StrideT values cannot
 *  exceed the range of the backend's size type (cuBLAS takes
 *  long long strides).
 */
template <typename T, typename SizeT, typename StrideT,
          typename=EnableWhen<IsSupportedType<T, BLAS_Op::GEMMSTRIDEDBATCHED>>>
void GemmStridedBatchedImpl(
    TransposeMode transpA, TransposeMode transpB,
    SizeT m, SizeT n, SizeT k,
    T const& alpha,
    T const* A, SizeT lda, StrideT strideA,
    T const* B, SizeT ldb, StrideT strideB,
    T const& beta,
    T* C, SizeT ldc, StrideT strideC,
    SizeT batchCount,
    SyncInfo<Device::GPU> const& si)
{
    // Pointer types expected by the native library.
    using NativePtr = MakePointer<NativeType<T>>;
    using ConstNativePtr = MakePointerToConst<NativeType<T>>;

    // Bind the library handle to this operation's stream.
    SyncManager mgr(GetLibraryHandle(), si);

    gpu_blas_impl::GemmStridedBatched(
        GetLibraryHandle(),
        ToNativeTransposeMode(transpA),
        ToNativeTransposeMode(transpB),
        ToSizeT(m), ToSizeT(n), ToSizeT(k),
        &alpha,
        reinterpret_cast<ConstNativePtr>(A), ToSizeT(lda), ToSizeT(strideA),
        reinterpret_cast<ConstNativePtr>(B), ToSizeT(ldb), ToSizeT(strideB),
        &beta,
        reinterpret_cast<NativePtr>(C), ToSizeT(ldc), ToSizeT(strideC),
        ToSizeT(batchCount));
}

template <typename T, typename SizeT,
typename=EnableWhen<IsSupportedType<T, BLAS_Op::DGMM>>>
void DgmmImpl(SideMode side,
Expand Down Expand Up @@ -514,6 +582,29 @@ void Copy2DStridedImpl(
si);
}

template <typename T, typename SizeT,
typename=EnableUnless<IsSupportedType<T,BLAS_Op::DOT>>,
typename=void>
void DotImpl(SizeT, T const*, SizeT, T const*, SizeT, T*,
SyncInfo<Device::GPU> const&)
{
std::ostringstream oss;
oss << "No valid implementation of DOT for T="
<< TypeTraits<T>::Name();
throw std::logic_error(oss.str());
}

template <typename T, typename SizeT,
typename=EnableUnless<IsSupportedType<T,BLAS_Op::DOT>>,
typename=void>
void Nrm2Impl(SizeT, T const*, SizeT, T*, SyncInfo<Device::GPU> const&)
{
std::ostringstream oss;
oss << "No valid implementation of NRM2 for T="
<< TypeTraits<T>::Name();
throw std::logic_error(oss.str());
}

template <typename T, typename SizeT,
typename=EnableUnless<IsSupportedType<T,BLAS_Op::SCAL>>,
typename=void>
Expand Down Expand Up @@ -691,6 +782,27 @@ void GemmImpl(
throw std::logic_error(oss.str());
}

/** @brief Fallback overload selected when T has no strided-batched
 *  GEMM support.
 *
 *  Always throws std::logic_error naming the unsupported type.
 */
template <
    typename T, typename SizeT, typename StrideT,
    typename=EnableUnless<IsSupportedType<T, BLAS_Op::GEMMSTRIDEDBATCHED>>,
    typename=void>
void GemmStridedBatchedImpl(
    TransposeMode, TransposeMode,
    SizeT, SizeT, SizeT,
    T const&,
    T const*, SizeT, StrideT,
    T const*, SizeT, StrideT,
    T const&,
    T*, SizeT, StrideT,
    SizeT,
    SyncInfo<Device::GPU> const&)
{
    std::ostringstream msg;
    msg << "No valid implementation of GEMMSTRIDEDBATCHED for T="
        << TypeTraits<T>::Name();
    throw std::logic_error(msg.str());
}

template <typename T, typename SizeT,
typename=EnableUnless<IsSupportedType<T,BLAS_Op::DGMM>>,
typename=void>
Expand Down Expand Up @@ -788,6 +900,25 @@ void Copy(SizeT num_rows, SizeT num_cols,
B, row_stride_B, ldb, si);
}

/** @brief Public entry point for the dot product.
 *
 *  Dispatches to the supported backend implementation when one
 *  exists for T, otherwise to a fallback that throws.
 */
template <typename T, typename SizeT>
void Dot(SizeT num_entries,
         T const* X, SizeT stride_X,
         T const* Y, SizeT stride_Y,
         T* result,
         SyncInfo<Device::GPU> const& syncinfo)
{
    details::DotImpl(
        num_entries, X, stride_X, Y, stride_Y, result, syncinfo);
}

/** @brief Public entry point for the vector norm.
 *
 *  Dispatches to the supported backend implementation when one
 *  exists for T, otherwise to a fallback that throws.
 */
template <typename T, typename SizeT>
void Nrm2(SizeT num_entries,
          T const* X, SizeT stride_X,
          T* result,
          SyncInfo<Device::GPU> const& syncinfo)
{
    details::Nrm2Impl(
        num_entries, X, stride_X, result, syncinfo);
}

template <typename T, typename SizeT>
void Scale(SizeT size,
T const& alpha,
Expand Down Expand Up @@ -895,6 +1026,29 @@ void Gemm(
beta, C, ldc, si);
}

/** @brief Public entry point for strided-batched GEMM.
 *
 *  Dispatches to the supported backend implementation when one
 *  exists for T, otherwise to a fallback that throws.
 */
template <typename T, typename SizeT, typename StrideT>
void GemmStridedBatched(
    TransposeMode transpA, TransposeMode transpB,
    SizeT m, SizeT n, SizeT k,
    T const& alpha,
    T const* A, SizeT lda, StrideT strideA,
    T const* B, SizeT ldb, StrideT strideB,
    T const& beta,
    T* C, SizeT ldc, StrideT strideC,
    SizeT batchCount,
    SyncInfo<Device::GPU> const& si)
{
    details::GemmStridedBatchedImpl(
        transpA, transpB, m, n, k,
        alpha, A, lda, strideA, B, ldb, strideB,
        beta, C, ldc, strideC,
        batchCount, si);
}

//
// BLAS-like Extension Routines
//
Expand Down
4 changes: 2 additions & 2 deletions include/hydrogen/device/gpu/cuda/cuBLAS_API.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ namespace cublas
int n, \
ScalarType const* X, int incx, \
ScalarType const* Y, int incy, \
ScalarType& output)
ScalarType* output)

#define ADD_NRM2_DECL(ScalarType) \
void Nrm2(cublasHandle_t handle, \
int n, \
ScalarType const* X, int incx, \
ScalarType& output)
ScalarType* output)

#define ADD_SCALE_DECL(ScalarType) \
void Scale(cublasHandle_t handle, \
Expand Down
7 changes: 7 additions & 0 deletions src/hydrogen/device/cuBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,12 @@ void SetPointerMode(PointerMode mode)
? CUBLAS_POINTER_MODE_HOST
: CUBLAS_POINTER_MODE_DEVICE)));
}
/** @brief Ask cuBLAS to use tensor-op math where available.
 *
 *  This sets the math mode on the shared cuBLAS handle; it is a hint
 *  only, per the contract documented in BLAS_Common.hpp.
 *
 *  NOTE(review): CUBLAS_TENSOR_OP_MATH is deprecated starting with
 *  CUDA 11 (CUBLAS_DEFAULT_MATH / CUBLAS_TF32_TENSOR_OP_MATH are the
 *  replacements) — confirm against the project's minimum supported
 *  CUDA version.
 */
void RequestTensorOperations()
{
    H_CHECK_CUBLAS(
        cublasSetMathMode(cublas::GetLibraryHandle(),
                          CUBLAS_TENSOR_OP_MATH));
}

}// namespace gpu_blas
}// namespace hydrogen
23 changes: 16 additions & 7 deletions src/hydrogen/device/cuBLAS_API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,30 +39,30 @@ void Dot(cublasHandle_t handle,
int n,
__half const* X, int incx,
__half const* Y, int incy,
__half& output)
__half* output)
{
H_CHECK_CUBLAS(
cublasDotEx(
handle,
n,
X, /*xtype=*/CUDA_R_16F, incx,
Y, /*ytype=*/CUDA_R_16F, incy,
&output,
output,
/*resulttype=*/CUDA_R_16F,
/*executiontype=*/CUDA_R_32F));
}

void Nrm2(cublasHandle_t handle,
int n,
__half const* X, int incx,
__half& output)
__half* output)
{
H_CHECK_CUBLAS(
cublasNrm2Ex(
handle,
n,
X, /*xtype=*/CUDA_R_16F, incx,
&output,
output,
/*resulttype=*/CUDA_R_16F,
/*executiontype=*/CUDA_R_32F));
}
Expand Down Expand Up @@ -133,6 +133,8 @@ struct RealTypeT<cuDoubleComplex>
template <typename T>
using RealType = typename RealTypeT<T>::type;

// For complex DOT, assume most use-cases will want the inner
// product in a complex vector space.
#define ADD_COMPLEX_DOT_IMPL(ScalarType, TypeChar) \
void Dotu(cublasHandle_t handle, \
int n, ScalarType const* X, int incx, \
Expand All @@ -145,14 +147,21 @@ using RealType = typename RealTypeT<T>::type;
n, X, incx, Y, incy, output)); \
} \
void Dotc(cublasHandle_t handle, \
int n, ScalarType const* X, int incx, \
ScalarType const* Y, int incy, \
ScalarType* output) \
int n, ScalarType const* X, int incx, \
ScalarType const* Y, int incy, \
ScalarType* output) \
{ \
H_CHECK_CUBLAS( \
cublas ## TypeChar ## dotc( \
handle, \
n, X, incx, Y, incy, output)); \
} \
void Dot(cublasHandle_t handle, \
int n, ScalarType const* X, int incx, \
ScalarType const* Y, int incy, \
ScalarType* output) \
{ \
Dotc(handle, n, X, incx, Y, incy, output); \
}

#define ADD_NRM2_IMPL(ScalarType, TypeChar) \
Expand Down
5 changes: 5 additions & 0 deletions src/hydrogen/device/rocBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ void SetPointerMode(PointerMode mode)
? rocblas_pointer_mode_host
: rocblas_pointer_mode_device)));
}
/** @brief Ask the library to use tensor-op math where available.
 *
 *  On the rocBLAS backend this hint is accepted but takes no action;
 *  no library state is modified.
 */
void RequestTensorOperations()
{
    // Nothing to do here.
}

}// namespace gpu_blas

}// namespace hydrogen

0 comments on commit 09128aa

Please sign in to comment.