diff --git a/include/hydrogen/blas/BLAS_Common.hpp b/include/hydrogen/blas/BLAS_Common.hpp index 0dbf2d6967..d565d9241f 100644 --- a/include/hydrogen/blas/BLAS_Common.hpp +++ b/include/hydrogen/blas/BLAS_Common.hpp @@ -138,6 +138,16 @@ namespace gpu_blas { /** @brief Set the pointer mode of the underlying library. */ void SetPointerMode(PointerMode mode); + +/** @brief Request that the underlying library use specialized tensor + * instructions. + * + * This is not a guarantee that such operations are available or will + * be used. However, if the library/hardware does expose such + * features, this will suggest to the library that they be used + * whenever possible. + */ +void RequestTensorOperations(); } }// namespace hydrogen #endif // HYDROGEN_BLAS_COMMON_HPP_ diff --git a/include/hydrogen/blas/GPU_BLAS_decl.hpp b/include/hydrogen/blas/GPU_BLAS_decl.hpp index 2825b961f2..b329832889 100644 --- a/include/hydrogen/blas/GPU_BLAS_decl.hpp +++ b/include/hydrogen/blas/GPU_BLAS_decl.hpp @@ -306,7 +306,7 @@ void Dot(SizeT num_entries, * @param num_entries The number of entries in X. * @param X The vector (device memory). * @param stride_X The stride of X. - * @param result The result of the dot product (host or device memory). + * @param result The result of the norm (host or device memory). * @param[in] syncinfo The synchronization information for this * operation. * @@ -538,7 +538,6 @@ void GemmStridedBatched( SizeT batchCount, SyncInfo const& syncinfo); - ///@} /** @name BLAS-like Extension Routines */ ///@{ diff --git a/include/hydrogen/blas/GPU_BLAS_impl.hpp b/include/hydrogen/blas/GPU_BLAS_impl.hpp index e5c6675543..2a79364dda 100644 --- a/include/hydrogen/blas/GPU_BLAS_impl.hpp +++ b/include/hydrogen/blas/GPU_BLAS_impl.hpp @@ -177,6 +177,44 @@ void Copy2DImpl(SizeT nrows, SizeT ncols, reinterpret_cast(B), ldb); } +template >> +void DotImpl(SizeT num_entries, + T const* X, SizeT stride_X, + T const* Y, SizeT stride_Y, + T* result, + SyncInfo const& si) +{ + using NTP = MakePointer>; + using CNTP = MakePointerToConst>; + + SyncManager mgr(GetLibraryHandle(), si); + gpu_blas_impl::Dot( + GetLibraryHandle(), + num_entries, + reinterpret_cast(X), stride_X, + reinterpret_cast(Y), stride_Y, + reinterpret_cast(result)); +} + +template >> +void Nrm2Impl(SizeT num_entries, + T const* X, SizeT stride_X, + T* result, + SyncInfo const& si) +{ + using NTP = MakePointer>; + using CNTP = MakePointerToConst>; + + SyncManager mgr(GetLibraryHandle(), si); + gpu_blas_impl::Nrm2( + GetLibraryHandle(), + num_entries, + reinterpret_cast(X), stride_X, + reinterpret_cast(result)); +} + template >> void ScaleImpl(SizeT num_entries, @@ -384,6 +422,36 @@ void GemmImpl( reinterpret_cast(C), ToSizeT(ldc)); } +template >> +void GemmStridedBatchedImpl( + TransposeMode transpA, TransposeMode transpB, + SizeT m, SizeT n, SizeT k, + T const& alpha, + T const* A, SizeT lda, StrideT strideA, + T const* B, SizeT ldb, StrideT strideB, + T const& beta, + T* C, SizeT ldc, StrideT strideC, + SizeT batchCount, + SyncInfo const& si) +{ + using NTP = MakePointer>; + using CNTP = MakePointerToConst>; + + SyncManager mgr(GetLibraryHandle(), si); + gpu_blas_impl::GemmStridedBatched( + GetLibraryHandle(), + ToNativeTransposeMode(transpA), + ToNativeTransposeMode(transpB), + ToSizeT(m), ToSizeT(n), ToSizeT(k), + &alpha, + reinterpret_cast(A), ToSizeT(lda), ToSizeT(strideA), + reinterpret_cast(B), ToSizeT(ldb), ToSizeT(strideB), + &beta, + reinterpret_cast(C), ToSizeT(ldc), ToSizeT(strideC), + ToSizeT(batchCount)); +} + template >> void DgmmImpl(SideMode side, @@ -514,6 +582,29 @@ void Copy2DStridedImpl( si); } +template >, + typename=void> +void DotImpl(SizeT, T const*, SizeT, T const*, SizeT, T*, + SyncInfo const&) +{ + std::ostringstream oss; + oss << "No valid implementation of DOT for T=" + << TypeTraits::Name(); + throw std::logic_error(oss.str()); +} + +template >, + typename=void> +void Nrm2Impl(SizeT, T const*, SizeT, T*, SyncInfo const&) +{ + std::ostringstream oss; + oss << "No valid implementation of NRM2 for T=" + << TypeTraits::Name(); + throw std::logic_error(oss.str()); +} + template >, typename=void> @@ -691,6 +782,27 @@ void GemmImpl( throw std::logic_error(oss.str()); } +template < + typename T, typename SizeT, typename StrideT, + typename=EnableUnless>, + typename=void> +void GemmStridedBatchedImpl( + TransposeMode, TransposeMode, + SizeT, SizeT, SizeT, + T const&, + T const*, SizeT, StrideT, + T const*, SizeT, StrideT, + T const&, + T*, SizeT, StrideT, + SizeT, + SyncInfo const&) +{ + std::ostringstream oss; + oss << "No valid implementation of GEMMSTRIDEDBATCHED for T=" + << TypeTraits::Name(); + throw std::logic_error(oss.str()); +} + template >, typename=void> @@ -788,6 +900,25 @@ void Copy(SizeT num_rows, SizeT num_cols, B, row_stride_B, ldb, si); } +template +void Dot(SizeT num_entries, + T const* X, SizeT stride_X, + T const* Y, SizeT stride_Y, + T* result, + SyncInfo const& syncinfo) +{ + details::DotImpl(num_entries, X, stride_X, Y, stride_Y, result, syncinfo); +} + +template +void Nrm2(SizeT num_entries, + T const* X, SizeT stride_X, + T* result, + SyncInfo const& syncinfo) +{ + details::Nrm2Impl(num_entries, X, stride_X, result, syncinfo); +} + template void Scale(SizeT size, T const& alpha, @@ -895,6 +1026,29 @@ void Gemm( beta, C, ldc, si); } +template +void GemmStridedBatched( + TransposeMode transpA, TransposeMode transpB, + SizeT m, SizeT n, SizeT k, + T const& alpha, + T const* A, SizeT lda, StrideT strideA, + T const* B, SizeT ldb, StrideT strideB, + T const& beta, + T* C, SizeT ldc, StrideT strideC, + SizeT batchCount, + SyncInfo const& si) +{ + details::GemmStridedBatchedImpl(transpA, transpB, + m, n, k, + alpha, + A, lda, strideA, + B, ldb, strideB, + beta, + C, ldc, strideC, + batchCount, + si); +} + // // BLAS-like Extension Routines // diff --git a/include/hydrogen/device/gpu/cuda/cuBLAS_API.hpp b/include/hydrogen/device/gpu/cuda/cuBLAS_API.hpp index c946ba280b..ec23e17a33 100644 --- a/include/hydrogen/device/gpu/cuda/cuBLAS_API.hpp +++ b/include/hydrogen/device/gpu/cuda/cuBLAS_API.hpp @@ -29,13 +29,13 @@ namespace cublas int n, \ ScalarType const* X, int incx, \ ScalarType const* Y, int incy, \ - ScalarType& output) + ScalarType* output) #define ADD_NRM2_DECL(ScalarType) \ void Nrm2(cublasHandle_t handle, \ int n, \ ScalarType const* X, int incx, \ - ScalarType& output) + ScalarType* output) #define ADD_SCALE_DECL(ScalarType) \ void Scale(cublasHandle_t handle, \ diff --git a/src/hydrogen/device/cuBLAS.cpp b/src/hydrogen/device/cuBLAS.cpp index e6336cec63..57fd4d7ece 100644 --- a/src/hydrogen/device/cuBLAS.cpp +++ b/src/hydrogen/device/cuBLAS.cpp @@ -155,5 +155,12 @@ void SetPointerMode(PointerMode mode) ? CUBLAS_POINTER_MODE_HOST : CUBLAS_POINTER_MODE_DEVICE))); } +void RequestTensorOperations() +{ + H_CHECK_CUBLAS( + cublasSetMathMode(cublas::GetLibraryHandle(), + CUBLAS_TENSOR_OP_MATH)); +} + }// namespace gpu_blas }// namespace hydrogen diff --git a/src/hydrogen/device/cuBLAS_API.cpp b/src/hydrogen/device/cuBLAS_API.cpp index 5b493998f4..fd6714bad5 100644 --- a/src/hydrogen/device/cuBLAS_API.cpp +++ b/src/hydrogen/device/cuBLAS_API.cpp @@ -39,7 +39,7 @@ void Dot(cublasHandle_t handle, int n, __half const* X, int incx, __half const* Y, int incy, - __half& output) + __half* output) { H_CHECK_CUBLAS( cublasDotEx( @@ -47,7 +47,7 @@ void Dot(cublasHandle_t handle, n, X, /*xtype=*/CUDA_R_16F, incx, Y, /*ytype=*/CUDA_R_16F, incy, - &output, + output, /*resulttype=*/CUDA_R_16F, /*executiontype=*/CUDA_R_32F)); } @@ -55,14 +55,14 @@ void Dot(cublasHandle_t handle, void Nrm2(cublasHandle_t handle, int n, __half const* X, int incx, - __half& output) + __half* output) { H_CHECK_CUBLAS( cublasNrm2Ex( handle, n, X, /*xtype=*/CUDA_R_16F, incx, - &output, + output, /*resulttype=*/CUDA_R_16F, /*executiontype=*/CUDA_R_32F)); } @@ -133,6 +133,8 @@ struct RealTypeT template using RealType = typename RealTypeT::type; +// For complex DOT, assume most use-cases will want the inner +// producted in a complex vector space. #define ADD_COMPLEX_DOT_IMPL(ScalarType, TypeChar) \ void Dotu(cublasHandle_t handle, \ int n, ScalarType const* X, int incx, \ @@ -145,14 +147,21 @@ using RealType = typename RealTypeT::type; n, X, incx, Y, incy, output)); \ } \ void Dotc(cublasHandle_t handle, \ - int n, ScalarType const* X, int incx, \ - ScalarType const* Y, int incy, \ - ScalarType* output) \ + int n, ScalarType const* X, int incx, \ + ScalarType const* Y, int incy, \ + ScalarType* output) \ { \ H_CHECK_CUBLAS( \ cublas ## TypeChar ## dotc( \ handle, \ n, X, incx, Y, incy, output)); \ + } \ + void Dot(cublasHandle_t handle, \ + int n, ScalarType const* X, int incx, \ + ScalarType const* Y, int incy, \ + ScalarType* output) \ + { \ + Dotc(handle, n, X, incx, Y, incy, output); \ } #define ADD_NRM2_IMPL(ScalarType, TypeChar) \ diff --git a/src/hydrogen/device/rocBLAS.cpp b/src/hydrogen/device/rocBLAS.cpp index a5dd66addc..5e3106d141 100644 --- a/src/hydrogen/device/rocBLAS.cpp +++ b/src/hydrogen/device/rocBLAS.cpp @@ -143,6 +143,11 @@ void SetPointerMode(PointerMode mode) ? rocblas_pointer_mode_host : rocblas_pointer_mode_device))); } +void RequestTensorOperations() +{ +// Nothing to do here. +} + }// namespace gpu_blas }// namespace hydrogen