From fb93c3988a167e2fb8b636b6b5855705aa80a9d8 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Tue, 12 Jul 2022 04:06:32 +0000 Subject: [PATCH] [build] Split `.cu` to improve compile times (#81193) The goal is to speed up CUDA builds. I was looking at bulid times and found that we have large CUDA compilation units that take forever to compile and make parallelism less effective. This PR splits them up into different `.cu` files so we can parallelize compilation better. We've done this sort of thing in the past with some success. With a cold build, timing before: 5m42.019s, timing after: 4m30.275s. That's a speedup of 18.1% for me. Behaviorally this should be a no-op, I'm just moving code around. There is still more we can do here but I did most of the ones that are copypasta. The full list of remaining chonky compilation units is [here](https://gist.github.com/suo/0dc217733f40f59898a8cc4f60529d60). ## Details Here's a screenshot from a ninja trace, with the following command: ``` MAX_JOBS=64 CCACHE_DISABLE=1 TORCH_CUDA_ARCH_LIST=Ampere BUILD_CAFFE2_OPS=0 USE_FBGEMM=0 USE_DISTRIBUTED=0 USE_MKLDNN=0 BUILD_TEST=0 USE_GOLD_LINKER=1 USE_OPENMP=1 USE_NCCL=0 DEBUG=0 python setup.py develop ``` image ([source trace](https://gist.github.com/suo/5f5458f2630f9ab6dcbea6989e892195), which you can visualize in [perfetto](https://ui.perfetto.dev/)) After this PR, we get somewhat better utilization (although there is plenty still left to do): image ([source trace](https://gist.github.com/suo/5607335bcd4bd412d42b0c9334259184)) Pull Request resolved: https://github.com/pytorch/pytorch/pull/81193 Approved by: https://github.com/cpuhrsch, https://github.com/malfet --- aten/src/ATen/cuda/cub-RadixSortKeys.cu | 108 +++ aten/src/ATen/cuda/cub-RadixSortPairs.cu | 93 +++ aten/src/ATen/cuda/cub.cu | 112 --- aten/src/ATen/native/cuda/Activation.cu | 710 ------------------ .../ATen/native/cuda/ActivationEluKernel.cu | 88 +++ .../ATen/native/cuda/ActivationGeluKernel.cu | 90 +++ .../ATen/native/cuda/ActivationGluKernel.cu | 143 ++++ .../native/cuda/ActivationHardshrinkKernel.cu | 41 + .../cuda/ActivationHardsigmoidKernel.cu | 76 ++ .../native/cuda/ActivationHardswishKernel.cu | 65 ++ .../native/cuda/ActivationHardtanhKernel.cu | 46 ++ .../native/cuda/ActivationLeakyReluKernel.cu | 64 ++ .../native/cuda/ActivationLogSigmoidKernel.cu | 66 ++ .../ATen/native/cuda/ActivationMishKernel.cu | 66 ++ .../ATen/native/cuda/ActivationPreluKernel.cu | 175 +++++ .../ATen/native/cuda/ActivationSiluKernel.cu | 61 ++ .../native/cuda/ActivationSoftplusKernel.cu | 76 ++ .../native/cuda/ActivationSoftshrinkKernel.cu | 60 ++ .../native/cuda/ActivationThresholdKernel.cu | 54 ++ .../ATen/native/cuda/BinaryDivFloorKernel.cu | 112 +++ .../ATen/native/cuda/BinaryDivTrueKernel.cu | 63 ++ .../ATen/native/cuda/BinaryDivTruncKernel.cu | 55 ++ aten/src/ATen/native/cuda/BinaryInternal.h | 48 ++ .../ATen/native/cuda/BinaryMulDivKernel.cu | 222 ------ aten/src/ATen/native/cuda/BinaryMulKernel.cu | 50 ++ .../ATen/native/cuda/ReduceAMinMaxKernel.cu | 51 ++ .../ATen/native/cuda/ReduceArgMaxKernel.cu | 48 ++ .../ATen/native/cuda/ReduceArgMinKernel.cu | 48 ++ .../ATen/native/cuda/ReduceMaxValuesKernel.cu | 63 ++ .../ATen/native/cuda/ReduceMinMaxKernel.cu | 168 ----- .../ATen/native/cuda/ReduceMinValuesKernel.cu | 58 ++ aten/src/ATen/native/cuda/Sort.cu | 216 ------ aten/src/ATen/native/cuda/Sort.h | 7 +- aten/src/ATen/native/cuda/SortStable.cu | 299 ++++++++ aten/src/ATen/native/cuda/SortStable.h | 19 + .../native/cuda/UnaryGeometricAcosKernel.cu | 56 ++ .../native/cuda/UnaryGeometricAcoshKernel.cu | 56 ++ .../native/cuda/UnaryGeometricAsinKernel.cu | 52 ++ .../native/cuda/UnaryGeometricAsinhKernel.cu | 56 ++ .../native/cuda/UnaryGeometricAtanKernel.cu | 56 ++ .../native/cuda/UnaryGeometricAtanhKernel.cu | 56 ++ .../native/cuda/UnaryGeometricCosKernel.cu | 55 ++ .../native/cuda/UnaryGeometricCoshKernel.cu | 56 ++ .../ATen/native/cuda/UnaryGeometricKernels.cu | 480 ------------ .../native/cuda/UnaryGeometricSinKernel.cu | 55 ++ .../native/cuda/UnaryGeometricSinhKernel.cu | 56 ++ .../native/cuda/UnaryGeometricTanKernel.cu | 55 ++ .../native/cuda/UnaryGeometricTanhKernel.cu | 56 ++ 48 files changed, 2952 insertions(+), 1914 deletions(-) create mode 100644 aten/src/ATen/cuda/cub-RadixSortKeys.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs.cu delete mode 100644 aten/src/ATen/native/cuda/Activation.cu create mode 100644 aten/src/ATen/native/cuda/ActivationEluKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationGeluKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationGluKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationHardswishKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationMishKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationPreluKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationSiluKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu create mode 100644 aten/src/ATen/native/cuda/ActivationThresholdKernel.cu create mode 100644 aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu create mode 100644 aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu create mode 100644 aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu create mode 100644 aten/src/ATen/native/cuda/BinaryInternal.h delete mode 100644 aten/src/ATen/native/cuda/BinaryMulDivKernel.cu create mode 100644 aten/src/ATen/native/cuda/BinaryMulKernel.cu create mode 100644 aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu create mode 100644 aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu create mode 100644 aten/src/ATen/native/cuda/ReduceArgMinKernel.cu create mode 100644 aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu delete mode 100644 aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu create mode 100644 aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu create mode 100644 aten/src/ATen/native/cuda/SortStable.cu create mode 100644 aten/src/ATen/native/cuda/SortStable.h create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu delete mode 100644 aten/src/ATen/native/cuda/UnaryGeometricKernels.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu create mode 100644 aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu diff --git a/aten/src/ATen/cuda/cub-RadixSortKeys.cu b/aten/src/ATen/cuda/cub-RadixSortKeys.cu new file mode 100644 index 0000000000000..3b585cf165cff --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortKeys.cu @@ -0,0 +1,108 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at { +namespace cuda { +namespace cub { + +template +void radix_sort_keys( + const key_t* keys_in, + key_t* keys_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending, + keys_in_, + keys_out_, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys, + keys_in_, + keys_out_, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } +} + +template +void unique( + const scalar_t* input, + scalar_t* output, + int64_t* num_selected_out, + int64_t num_items) { + TORCH_CHECK( + num_items <= std::numeric_limits::max(), + "cub unique does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique, + input, + output, + num_selected_out, + num_items, + at::cuda::getCurrentCUDAStream()); +} + +template +void run_length_encode( + const scalar_t* input, + scalar_t* output, + int64_t* counts_out, + int64_t* length_out, + int64_t num_items) { + TORCH_CHECK( + num_items <= std::numeric_limits::max(), + "cub run_length_encode does not support more than INT_MAX elements"); + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode, + input, + output, + counts_out, + length_out, + num_items, + at::cuda::getCurrentCUDAStream()); +} + +#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \ + template void radix_sort_keys( \ + const scalar_t* keys_in, \ + scalar_t* keys_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); \ + template void unique( \ + const scalar_t* input, \ + scalar_t* output, \ + int64_t* num_selected_out, \ + int64_t num_items); \ + template void run_length_encode( \ + const scalar_t* input, \ + scalar_t* output, \ + int64_t* counts_out, \ + int64_t* length_out, \ + int64_t n); + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) + +} // namespace cub +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs.cu b/aten/src/ATen/cuda/cub-RadixSortPairs.cu new file mode 100644 index 0000000000000..3c28a7141cf26 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs.cu @@ -0,0 +1,93 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include + +namespace at { +namespace cuda { +namespace cub { +namespace detail { + +template +void radix_sort_pairs_impl( + const key_t* keys_in, + key_t* keys_out, + const OpaqueType* values_in, + OpaqueType* values_out, + int64_t n, + bool descending, + int64_t begin_bit, + int64_t end_bit) { + TORCH_CHECK( + n <= std::numeric_limits::max(), + "cub sort does not support sorting more than INT_MAX elements"); + using key_t_ = typename detail::cuda_type::type; + + auto allocator = c10::cuda::CUDACachingAllocator::get(); + c10::DataPtr keys_out_owner; + + if (keys_out == nullptr) { + keys_out_owner = allocator->allocate(n * sizeof(key_t)); + keys_out = reinterpret_cast(keys_out_owner.get()); + } + + const key_t_* keys_in_ = reinterpret_cast(keys_in); + key_t_* keys_out_ = reinterpret_cast(keys_out); + + if (descending) { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } else { + CUB_WRAPPER( + NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs, + keys_in_, + keys_out_, + values_in, + values_out, + n, + begin_bit, + end_bit, + c10::cuda::getCurrentCUDAStream()); + } +} + +#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ + template void radix_sort_pairs_impl( \ + const key_t* keys_in, \ + key_t* keys_out, \ + const OpaqueType* values_in, \ + OpaqueType* values_out, \ + int64_t n, \ + bool descending, \ + int64_t begin_bit, \ + int64_t end_bit); + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) +AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) +AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) +AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) + +#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ + AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) + +// BFloat16 Radix sort is supported from ROCm 4.5 onwards +#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500) +AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) +#endif + +} // namespace detail + +} // namespace cub +} // namespace cuda +} // namespace at diff --git a/aten/src/ATen/cuda/cub.cu b/aten/src/ATen/cuda/cub.cu index e1f87db1a295e..61aa7747e1999 100644 --- a/aten/src/ATen/cuda/cub.cu +++ b/aten/src/ATen/cuda/cub.cu @@ -5,118 +5,6 @@ namespace at { namespace cuda { namespace cub { -namespace detail { - -template -void radix_sort_pairs_impl( - const key_t *keys_in, key_t *keys_out, - const OpaqueType *values_in, OpaqueType *values_out, - int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) { - TORCH_CHECK(n <= std::numeric_limits::max(), - "cub sort does not support sorting more than INT_MAX elements"); - using key_t_ = typename detail::cuda_type::type; - - auto allocator = c10::cuda::CUDACachingAllocator::get(); - c10::DataPtr keys_out_owner; - - if (keys_out == nullptr) { - keys_out_owner = allocator->allocate(n * sizeof(key_t)); - keys_out = reinterpret_cast(keys_out_owner.get()); - } - - const key_t_ *keys_in_ = reinterpret_cast(keys_in); - key_t_ *keys_out_ = reinterpret_cast(keys_out); - - if (descending) { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending, - keys_in_, keys_out_, values_in, values_out, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } else { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs, - keys_in_, keys_out_, values_in, values_out, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } -} - -#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \ - template void radix_sort_pairs_impl( \ - const key_t *keys_in, key_t *keys_out, \ - const OpaqueType *values_in, \ - OpaqueType *values_out, \ - int64_t n, bool descending, int64_t begin_bit, int64_t end_bit); - -AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) - -#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ - AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) - -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) - -// BFloat16 Radix sort is supported from ROCm 4.5 onwards -#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500) -AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) -#endif - -} // namespace detail - -template -void radix_sort_keys( - const key_t *keys_in, key_t *keys_out, - int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) { - TORCH_CHECK(n <= std::numeric_limits::max(), - "cub sort does not support sorting more than INT_MAX elements"); - using key_t_ = typename detail::cuda_type::type; - - const key_t_ *keys_in_ = reinterpret_cast(keys_in); - key_t_ *keys_out_ = reinterpret_cast(keys_out); - - if (descending) { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending, - keys_in_, keys_out_, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } else { - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys, - keys_in_, keys_out_, n, - begin_bit, end_bit, c10::cuda::getCurrentCUDAStream()); - } -} - -template -void unique(const scalar_t *input, scalar_t *output, int64_t *num_selected_out, int64_t num_items) { - TORCH_CHECK(num_items <= std::numeric_limits::max(), - "cub unique does not support more than INT_MAX elements"); - CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique, - input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream()); -} - -template -void run_length_encode(const scalar_t *input, scalar_t *output, int64_t *counts_out, - int64_t *length_out, int64_t num_items) { - TORCH_CHECK(num_items <= std::numeric_limits::max(), - "cub run_length_encode does not support more than INT_MAX elements"); - CUB_WRAPPER( - NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode, - input, output, counts_out, length_out, num_items, - at::cuda::getCurrentCUDAStream()); -} - -#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \ - template void radix_sort_keys( \ - const scalar_t *keys_in, scalar_t *keys_out, int64_t n, \ - bool descending, int64_t begin_bit, int64_t end_bit); \ - template void unique( \ - const scalar_t *input, scalar_t *output, \ - int64_t *num_selected_out, int64_t num_items); \ - template void run_length_encode( \ - const scalar_t *input, scalar_t *output, int64_t *counts_out, \ - int64_t *length_out, int64_t n); - -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES) namespace { template diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu deleted file mode 100644 index 999e0d41bda52..0000000000000 --- a/aten/src/ATen/native/cuda/Activation.cu +++ /dev/null @@ -1,710 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#define _USE_MATH_DEFINES - -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at { -namespace native { - -// ----------------------------------- -// glu forward -// ----------------------------------- -void glu_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { - using opmath_t = at::opmath_type; - gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t { - const opmath_t a = a_; - const opmath_t b = b_; - const opmath_t one = opmath_t(1); - const opmath_t sigmoid = one / (one + std::exp(-b)); - return a * sigmoid; - }); - }); -} - -// ----------------------------------- -// glu forward ad -// ----------------------------------- -void glu_jvp_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { - using opmath_t = at::opmath_type; - gpu_kernel(iter, [] GPU_LAMBDA ( - scalar_t res_, - scalar_t b_, - scalar_t da_, - scalar_t db_) -> scalar_t { - const opmath_t res = res_; - const opmath_t b = b_; - const opmath_t da = da_; - const opmath_t db = db_; - const opmath_t one = opmath_t(1); - - const opmath_t sig_b = one / (one + std::exp(-b)); - return ( - da * sig_b + res * (db - sig_b * db) - ); - }); - }); -} - -// ----------------------------------- -// glu backward -// ----------------------------------- - -// Byte offsets don't require multiplication by sizeof(T), so are slightly cheaper. -// For fixed offsets, this removes all penalty from 64-bit indexing. -template -__device__ T* byte_offset(T* ptr, int64_t offset) { - using byte_ptr_t = typename std::conditional< - std::is_const::value, const char*, char*>::type; - return reinterpret_cast( - reinterpret_cast(ptr) + offset - ); -} - -template -__global__ void glu_backward_kernel( - int numel, scalar_t* gI, const scalar_t* I, const scalar_t* gO, - OffsetCalc offset_calculator, - int64_t gI_byte_offset, int64_t I_byte_offset) { - using opmath_t = at::opmath_type; - - const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; - if (linear_index >= numel) { - return; - } - const auto offsets = offset_calculator.get(linear_index); - - // We explicitly iterate over the first half of the input tensor, and - // gI_byte_offset and I_byte_offset are the offsets to access the - // corresponding index in the second half of the tensor. - const opmath_t a = I[offsets[1]]; - const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset); - const opmath_t gO_val = gO[offsets[2]]; - - const auto one = opmath_t(1); - const opmath_t sigmoid = one / (one + std::exp(-b)); - - auto* gA = gI + offsets[0]; - *gA = sigmoid * gO_val; - - auto* gB = byte_offset(gA, gI_byte_offset); - *gB = (one - sigmoid) * sigmoid * gO_val * a; -} - -void launch_glu_backward_kernel(const TensorIteratorBase& iter, - int64_t gI_stride, int64_t I_stride) { - const auto N = iter.numel(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(N > 0 && N <= std::numeric_limits::max()); - const auto offset_calculator = make_element_offset_calculator<3>(iter); - constexpr int64_t block_size = 256; - const int64_t grid = (N + block_size - 1) / block_size; - const auto stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "glu_backward_cuda", [&] { - auto gI = static_cast(iter.data_ptr(0)); - auto I = static_cast(iter.data_ptr(1)); - auto gO = static_cast(iter.data_ptr(2)); - glu_backward_kernel<<>>( - N, gI, I, gO, offset_calculator, - gI_stride * sizeof(scalar_t), I_stride * sizeof(scalar_t)); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// ----------------------------------- -// log_sigmoid forward -// ----------------------------------- - -void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), - "log_sigmoid_forward_cuda", [&] { - using opmath_t = at::opmath_type; - - gpu_kernel(iter, - [] GPU_LAMBDA (scalar_t in_) -> scalar_t { - const opmath_t in = in_; - const auto min = std::min(opmath_t(0), in); - const auto z = std::exp(-std::abs(in)); - return min - std::log1p(z); - }); - }); -} - -// ----------------------------------- -// log_sigmoid backward -// ----------------------------------- - -void log_sigmoid_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), - "log_sigmoid_backward_cuda", [&] { - using opmath_t = at::opmath_type; - gpu_kernel(iter, - [] GPU_LAMBDA (scalar_t in_, scalar_t grad_out_) -> scalar_t { - const opmath_t in = in_; - const opmath_t grad_out = grad_out_; - - auto in_negative = in < opmath_t(0); - auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0); - auto sign = in_negative ? opmath_t(1) : -opmath_t(1); - const auto z = std::exp(-std::abs(in)); - return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z))); - }); - }); -} - -// ----------------------------------- -// prelu forward -// ----------------------------------- -void launch_prelu_cuda_kernel_share_weights(TensorIteratorBase &iter, const TensorBase &weight) { - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_cuda", [&] { - const auto *weight_data = weight.data_ptr(); - at::native::gpu_kernel(iter, - [weight_data] GPU_LAMBDA (scalar_t input_val) { - return (input_val > 0) ? input_val : *weight_data * input_val; - }); - }); -} - -template -__global__ void prelu_cuda_kernel_multi_weights( - scalar_t* result_data, - const scalar_t* input_data, - const scalar_t* weight_data, - int64_t input_stride0, - int64_t input_stride1, - int64_t input_numel) { - - int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; - if (linearId >= input_numel) return; - - // multiply values at each channel with weight[channel_index] - int64_t channel = (linearId % input_stride0) / input_stride1; - scalar_t input_data_val = input_data[linearId]; - result_data[linearId] = (input_data_val > 0) ? input_data_val : weight_data[channel] * input_data_val; -} - -void launch_prelu_cuda_kernel_multi_weights( - const TensorBase &result, const TensorBase &input, const TensorBase &weight) { - int64_t input_ndim = input.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - int64_t input_stride0 = 1, input_stride1 = 1; - - if (input_ndim > 1) { - channel_size = input.size(1); // channel is the 2nd dim of input - auto strides = input.strides(); - input_stride0 = strides[0]; - input_stride1 = strides[1]; - } - const int64_t weight_num = weight.numel(); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // config to run cuda kernel - int64_t input_numel = input.numel(); - const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions"); - - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] { - prelu_cuda_kernel_multi_weights - <<>>( - result.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - input_stride0, - input_stride1, - input_numel); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// ----------------------------------- -// prelu backward -// ----------------------------------- -void launch_prelu_cuda_backward_kernel_share_weights( - TensorIteratorBase &iter, const TensorBase &weight) { - // N.B. `std::tuple` does not support `::operator=` on device code. - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_backward_cuda", [&] { - const auto *weight_data = weight.data_ptr(); - gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple { - scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out; - scalar_t weight_grad_collector = input > 0 ? scalar_t(0) : input * grad_out; - return {input_grad, weight_grad_collector}; - }); - }); -} - -template -__global__ void prelu_cuda_backward_kernel_multi_weights( - const scalar_t* input_data, - const scalar_t* weight_data, - const scalar_t* grad_out_data, - scalar_t* input_grad_data, - scalar_t* weight_grad_collector, - int64_t input_stride0, - int64_t input_stride1, - int64_t input_numel) { - - int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; - if (linearId >= input_numel) return; - int64_t channel = (linearId % input_stride0) / input_stride1; - scalar_t input_data_val = input_data[linearId]; - scalar_t grad_out_data_val = grad_out_data[linearId]; - input_grad_data[linearId] = (input_data_val > 0) ? grad_out_data_val : weight_data[channel] * grad_out_data_val; - weight_grad_collector[linearId] = (input_data_val > 0) ? scalar_t(0) : input_data_val * grad_out_data_val; -} - -void launch_prelu_cuda_backward_kernel_multi_weights( - const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out, - const TensorBase &input_grad, const TensorBase &weight_grad_collector) { - int64_t input_ndim = input.dim(); - TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); - - int64_t channel_size = 1; // channel_size default to 1 - int64_t input_stride0 = 1, input_stride1 = 1; - - if (input_ndim > 1) { - channel_size = input.size(1); // channel is the 2nd dim of input - auto strides = input.strides(); - input_stride0 = strides[0]; - input_stride1 = strides[1]; - } - const int64_t weight_num = weight.numel(); - TORCH_CHECK(channel_size == weight_num, - "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, - " and channel size = ", channel_size, "."); - - // config to run cuda kernel - int64_t input_numel = input.numel(); - const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); - dim3 grid; - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions"); - - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] { - prelu_cuda_backward_kernel_multi_weights - <<>>( - input.data_ptr(), - weight.data_ptr(), - grad_out.data_ptr(), - input_grad.data_ptr(), - weight_grad_collector.data_ptr(), - input_stride0, - input_stride1, - input_numel); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -// ----------------------------------- -// hardshrink -// ----------------------------------- -void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardshrink_cuda", [&]() { - auto lambd = value.to(); - gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t { - return (a >= -lambd && a <= lambd) ? scalar_t(0) : a; - }); - }); -} - -void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() { - auto lambd = value.to(); - gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t { - return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)); - }); - }); -} - -void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "shrink_backward_cuda", [&]() { - auto lambd = value.to(); - gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t { - return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) : grad_val; - }); - }); -} - -void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) { - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, - iter.dtype(), "hardtanh_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto min_val = min.to(); - auto max_val = max.to(); - gpu_kernel(iter, [min_val, max_val]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - return (bop <= min_val) || (bop >= max_val) ? opmath_t(0) : aop; - }); - }); -} - -void softplus_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "softplus_cuda", [&]() { - using opmath_t = at::opmath_type; - auto beta = beta_.to(); - auto threshold = threshold_.to(); - gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t { - opmath_t aop = static_cast(a); - return (aop * beta) > threshold ? aop : (::log1p(std::exp(aop * beta))) / beta; - }); - }); -} - -void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, const Scalar& threshold_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "softplus_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto beta = beta_.to(); - auto threshold = threshold_.to(); - gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - opmath_t z = std::exp(bop * beta); - return (bop * beta) > threshold ? aop : aop * z / (z + opmath_t(1.)); - }); - }); -} - -template -void threshold_kernel_impl(TensorIteratorBase& iter, scalar_t threshold, scalar_t value) { - gpu_kernel_with_scalars(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t { - return x <= threshold ? value : other; - }); -} - -static void threshold_kernel_cuda(TensorIteratorBase& iter, const Scalar& threshold, const Scalar& value) { - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "threshold_cuda", [&] { - threshold_kernel_impl(iter, threshold.to(), value.to()); - }); -} - -void elu_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negcoef = alpha.to() * scale.to(); - auto poscoef = scale.to(); - auto negiptcoef = input_scale.to(); - gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a) -> scalar_t { - opmath_t aop = static_cast(a); - return aop > 0 ? aop * poscoef : std::expm1(aop * negiptcoef) * negcoef; - }); - }); -} - -void elu_backward_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, bool is_result) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negcoef = alpha.to() * scale.to(); - auto poscoef = scale.to(); - auto negiptcoef = input_scale.to(); - gpu_kernel(iter, [negcoef, poscoef, negiptcoef, is_result]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - - if (is_result) { - return bop <= 0 ? aop * negiptcoef * (bop + negcoef) : aop * poscoef; - } else { - return bop <= 0 ? aop * negiptcoef * negcoef * std::exp(bop * negiptcoef) : aop * poscoef; - } - }); - }); -} - -void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { - if (approximate == GeluType::Tanh) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); - constexpr opmath_t kKappa = 0.044715; - auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); - auto inner = kBeta * (static_cast(x) + kKappa * x_cube); - return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner)); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kAlpha = M_SQRT1_2; - return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); - }); - }); - } -} - -void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { - if (approximate == GeluType::Tanh) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); - constexpr opmath_t kKappa = 0.044715; - auto x_sq = static_cast(x) * static_cast(x); - auto x_cube = x_sq * static_cast(x); - auto inner = kBeta * (static_cast(x) + kKappa * x_cube); - auto tanh_inner = c10::cuda::compat::tanh(inner); - - auto left = opmath_t(0.5) * static_cast(x); - auto right = opmath_t(1) + tanh_inner; - - auto left_derivative = 0.5 * right; - - auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; - auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); - auto right_derivative = left * tanh_derivative * inner_derivative; - - return static_cast(dy) * (left_derivative + right_derivative); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { - gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); - constexpr opmath_t kAlpha = M_SQRT1_2; - const opmath_t cdf = - opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); - const opmath_t pdf = - c10::cuda::compat::exp( - opmath_t(-0.5) * static_cast(x) * static_cast(x)) * - kBeta; - return static_cast(dy) * (cdf + static_cast(x) * pdf); - }); - }); - } -} - -namespace { - -void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "leaky_relu_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negval = negval_.to(); - gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t { - opmath_t aop = static_cast(a); - return aop > opmath_t(0) ? aop : aop * negval; - }); - }); -} - -void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "leaky_relu_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - auto negval = negval_.to(); - gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { - opmath_t aop = static_cast(a); - opmath_t bop = static_cast(b); - return aop > opmath_t(0) ? bop : bop * negval; - }); - }); -} - -void hardswish_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t one_sixth(1.0f / 6.0f); - const opmath_t three(3.0f); - const opmath_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - opmath_t x = static_cast(self_val); - return x * std::min(std::max(x + three, zero), six) * one_sixth; - }); - }); -} - -void hardswish_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t three(3.0f); - const opmath_t neg_three(-3.0f); - const opmath_t one_half(0.5f); - gpu_kernel( - iter, - [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { - opmath_t grad_val = static_cast(grad_val_); - opmath_t self_val = static_cast(self_val_); - if (self_val < neg_three) { - return zero; - } else if (self_val <= three) { - return grad_val * ((self_val / three) + one_half); - } else { - return grad_val; - } - }); - }); -} - -void hardsigmoid_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, - iter.dtype(), "hardsigmoid_cuda", [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t one_sixth(1.0f / 6.0f); - const opmath_t three(3.0f); - const opmath_t six(6.0f); - gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { - opmath_t x = static_cast(self_val); - return std::min(std::max(x + three, zero), six) * one_sixth; - }); - }); -} - -void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "hardsigmoid_backward_cuda", - [&]() { - using opmath_t = at::opmath_type; - const opmath_t zero(0.0f); - const opmath_t three(3.0f); - const opmath_t neg_three(-3.0f); - const opmath_t one_sixth(1.0f / 6.0f); - gpu_kernel( - iter, - [zero, three, neg_three, one_sixth]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { - opmath_t grad_val = static_cast(grad_val_); - opmath_t self_val = static_cast(self_val_); - return (self_val > neg_three && self_val < three) - ? grad_val * one_sixth - : zero; - }); - }); -} - -void silu_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "silu_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t x_acc = static_cast(x); - return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); - }); - }); -} - -void silu_backward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "silu_backward_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t dy_acc = static_cast(dy); - const opmath_t x_acc = static_cast(x); - const opmath_t s_acc = - opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); - return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc)); - }); - }); -} - -void mish_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "mish_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t x_acc = static_cast(x); - return x_acc * c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); - }); - }); -} - -void mish_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - iter.dtype(), - "mish_backward_cuda", - [&]() { - gpu_kernel( - iter, - [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { - using opmath_t = at::opmath_type; - const opmath_t dy_acc = static_cast(dy); - const opmath_t x_acc = static_cast(x); - const opmath_t s_acc = - opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); - const opmath_t t_acc = - c10::cuda::compat::tanh(c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); - return dy_acc * (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc)); - }); - }); -} - -} // namespace - -REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel); -REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel); -REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel); -REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel); -REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel); -REGISTER_DISPATCH(elu_stub, &elu_kernel); -REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel); -REGISTER_DISPATCH(glu_stub, &glu_kernel); -REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel); -REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); -REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); -REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel); -REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel); -REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); -REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); -REGISTER_DISPATCH(softplus_stub, &softplus_kernel); -REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); -REGISTER_DISPATCH(silu_stub, &silu_kernel); -REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); -REGISTER_DISPATCH(mish_stub, &mish_kernel); -REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); -REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda); - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu new file mode 100644 index 0000000000000..113e6da10eacd --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -0,0 +1,88 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void elu_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > 0 ? aop * poscoef + : std::expm1(aop * negiptcoef) * negcoef; + }); + }); +} + +void elu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "elu_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negcoef = alpha.to() * scale.to(); + auto poscoef = scale.to(); + auto negiptcoef = input_scale.to(); + gpu_kernel( + iter, + [negcoef, poscoef, negiptcoef, is_result] GPU_LAMBDA( + scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + + if (is_result) { + return bop <= 0 ? aop * negiptcoef * (bop + negcoef) + : aop * poscoef; + } else { + return bop <= 0 + ? aop * negiptcoef * negcoef * std::exp(bop * negiptcoef) + : aop * poscoef; + } + }); + }); +} +} // namespace + +REGISTER_DISPATCH(elu_stub, &elu_kernel); +REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu new file mode 100644 index 0000000000000..d3d7879d3b884 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -0,0 +1,90 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_cube = static_cast(x) * static_cast(x) * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + return opmath_t(0.5) * static_cast(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner)); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kAlpha = M_SQRT1_2; + return static_cast(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + }); + }); + } +} + +void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) { + if (approximate == GeluType::Tanh) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5); + constexpr opmath_t kKappa = 0.044715; + auto x_sq = static_cast(x) * static_cast(x); + auto x_cube = x_sq * static_cast(x); + auto inner = kBeta * (static_cast(x) + kKappa * x_cube); + auto tanh_inner = c10::cuda::compat::tanh(inner); + + auto left = opmath_t(0.5) * static_cast(x); + auto right = opmath_t(1) + tanh_inner; + + auto left_derivative = 0.5 * right; + + auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner; + auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq); + auto right_derivative = left * tanh_derivative * inner_derivative; + + return static_cast(dy) * (left_derivative + right_derivative); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, + it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() { + gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5); + constexpr opmath_t kAlpha = M_SQRT1_2; + const opmath_t cdf = + opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast(x) * kAlpha)); + const opmath_t pdf = + c10::cuda::compat::exp( + opmath_t(-0.5) * static_cast(x) * static_cast(x)) * + kBeta; + return static_cast(dy) * (cdf + static_cast(x) * pdf); + }); + }); + } +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu new file mode 100644 index 0000000000000..740edbbf38ee2 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -0,0 +1,143 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// ----------------------------------- +// glu forward +// ----------------------------------- +void glu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t { + const opmath_t a = a_; + const opmath_t b = b_; + const opmath_t one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + return a * sigmoid; + }); + }); +} + +// ----------------------------------- +// glu forward ad +// ----------------------------------- +void glu_jvp_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, + [] GPU_LAMBDA( + scalar_t res_, scalar_t b_, scalar_t da_, scalar_t db_) + -> scalar_t { + const opmath_t res = res_; + const opmath_t b = b_; + const opmath_t da = da_; + const opmath_t db = db_; + const opmath_t one = opmath_t(1); + + const opmath_t sig_b = one / (one + std::exp(-b)); + return (da * sig_b + res * (db - sig_b * db)); + }); + }); +} + +// ----------------------------------- +// glu backward +// ----------------------------------- + +// Byte offsets don't require multiplication by sizeof(T), so are slightly +// cheaper. For fixed offsets, this removes all penalty from 64-bit indexing. +template +__device__ T* byte_offset(T* ptr, int64_t offset) { + using byte_ptr_t = typename std:: + conditional::value, const char*, char*>::type; + return reinterpret_cast(reinterpret_cast(ptr) + offset); +} + +template +__global__ void glu_backward_kernel( + int numel, + scalar_t* gI, + const scalar_t* I, + const scalar_t* gO, + OffsetCalc offset_calculator, + int64_t gI_byte_offset, + int64_t I_byte_offset) { + using opmath_t = at::opmath_type; + + const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_index >= numel) { + return; + } + const auto offsets = offset_calculator.get(linear_index); + + // We explicitly iterate over the first half of the input tensor, and + // gI_byte_offset and I_byte_offset are the offsets to access the + // corresponding index in the second half of the tensor. + const opmath_t a = I[offsets[1]]; + const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset); + const opmath_t gO_val = gO[offsets[2]]; + + const auto one = opmath_t(1); + const opmath_t sigmoid = one / (one + std::exp(-b)); + + auto* gA = gI + offsets[0]; + *gA = sigmoid * gO_val; + + auto* gB = byte_offset(gA, gI_byte_offset); + *gB = (one - sigmoid) * sigmoid * gO_val * a; +} + +void launch_glu_backward_kernel( + const TensorIteratorBase& iter, + int64_t gI_stride, + int64_t I_stride) { + const auto N = iter.numel(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + N > 0 && N <= std::numeric_limits::max()); + const auto offset_calculator = make_element_offset_calculator<3>(iter); + constexpr int64_t block_size = 256; + const int64_t grid = (N + block_size - 1) / block_size; + const auto stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "glu_backward_cuda", [&] { + auto gI = static_cast(iter.data_ptr(0)); + auto I = static_cast(iter.data_ptr(1)); + auto gO = static_cast(iter.data_ptr(2)); + glu_backward_kernel<<>>( + N, + gI, + I, + gO, + offset_calculator, + gI_stride * sizeof(scalar_t), + I_stride * sizeof(scalar_t)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +REGISTER_DISPATCH(glu_stub, &glu_kernel); +REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu new file mode 100644 index 0000000000000..ae2f6b11b8523 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -0,0 +1,41 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardshrink_cuda", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return (a >= -lambd && a <= lambd) ? scalar_t(0) : a; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu new file mode 100644 index 0000000000000..ceafa53b72f1c --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -0,0 +1,76 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardsigmoid_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardsigmoid_cuda", + [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel( + iter, + [zero, one_sixth, three, six] GPU_LAMBDA( + scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardsigmoid_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "hardsigmoid_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_sixth(1.0f / 6.0f); + gpu_kernel( + iter, + [zero, three, neg_three, one_sixth] GPU_LAMBDA( + scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + return (self_val > neg_three && self_val < three) + ? grad_val * one_sixth + : zero; + }); + }); +} + +} // namespace + +REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel); +REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu new file mode 100644 index 0000000000000..7d952043ad872 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -0,0 +1,65 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardswish_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t one_sixth(1.0f / 6.0f); + const opmath_t three(3.0f); + const opmath_t six(6.0f); + gpu_kernel(iter, [zero, one_sixth, three, six]GPU_LAMBDA(scalar_t self_val) -> scalar_t { + opmath_t x = static_cast(self_val); + return x * std::min(std::max(x + three, zero), six) * one_sixth; + }); + }); +} + +void hardswish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_backward_cuda", [&]() { + using opmath_t = at::opmath_type; + const opmath_t zero(0.0f); + const opmath_t three(3.0f); + const opmath_t neg_three(-3.0f); + const opmath_t one_half(0.5f); + gpu_kernel( + iter, + [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t { + opmath_t grad_val = static_cast(grad_val_); + opmath_t self_val = static_cast(self_val_); + if (self_val < neg_three) { + return zero; + } else if (self_val <= three) { + return grad_val * ((self_val / three) + one_half); + } else { + return grad_val; + } + }); + }); +} +} // namespace + +REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel); +REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu new file mode 100644 index 0000000000000..1ef3fdba2898f --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -0,0 +1,46 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void hardtanh_backward_kernel( + TensorIterator& iter, + const Scalar& min, + const Scalar& max) { + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::Half, iter.dtype(), "hardtanh_backward_cuda", [&]() { + using opmath_t = at::opmath_type; + auto min_val = min.to(); + auto max_val = max.to(); + gpu_kernel( + iter, + [min_val, max_val] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return (bop <= min_val) || (bop >= max_val) ? opmath_t(0) : aop; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu new file mode 100644 index 0000000000000..c323aca1ca7fb --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -0,0 +1,64 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel(iter, [negval] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return aop > opmath_t(0) ? aop : aop * negval; + }); + }); +} + +void leaky_relu_backward_kernel( + TensorIteratorBase& iter, + const Scalar& negval_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "leaky_relu_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto negval = negval_.to(); + gpu_kernel( + iter, [negval] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + return aop > opmath_t(0) ? bop : bop * negval; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel); +REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu new file mode 100644 index 0000000000000..131462467d3dd --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -0,0 +1,66 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// ----------------------------------- +// log_sigmoid forward +// ----------------------------------- + +void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_cuda", [&] { + using opmath_t = at::opmath_type; + + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t { + const opmath_t in = in_; + const auto min = std::min(opmath_t(0), in); + const auto z = std::exp(-std::abs(in)); + return min - std::log1p(z); + }); + }); +} + +namespace { +// ----------------------------------- +// log_sigmoid backward +// ----------------------------------- +void log_sigmoid_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_backward_cuda", [&] { + using opmath_t = at::opmath_type; + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t in_, scalar_t grad_out_) -> scalar_t { + const opmath_t in = in_; + const opmath_t grad_out = grad_out_; + + auto in_negative = in < opmath_t(0); + auto max_deriv = in_negative ? opmath_t(1) : opmath_t(0); + auto sign = in_negative ? opmath_t(1) : -opmath_t(1); + const auto z = std::exp(-std::abs(in)); + return grad_out * (max_deriv - sign * (z / (opmath_t(1) + z))); + }); + }); +} +} // namespace + +REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu new file mode 100644 index 0000000000000..70c058644f666 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -0,0 +1,66 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void mish_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc * + c10::cuda::compat::tanh( + c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); + }); + }); +} + +void mish_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "mish_backward_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); + const opmath_t t_acc = c10::cuda::compat::tanh( + c10::cuda::compat::log1p(c10::cuda::compat::exp(x_acc))); + return dy_acc * + (t_acc + x_acc * s_acc * (opmath_t(1) - t_acc * t_acc)); + }); + }); +} +} // namespace + +REGISTER_DISPATCH(mish_stub, &mish_kernel); +REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationPreluKernel.cu b/aten/src/ATen/native/cuda/ActivationPreluKernel.cu new file mode 100644 index 0000000000000..0d8f09714698e --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationPreluKernel.cu @@ -0,0 +1,175 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +// ----------------------------------- +// prelu forward +// ----------------------------------- +void launch_prelu_cuda_kernel_share_weights(TensorIteratorBase &iter, const TensorBase &weight) { + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_cuda", [&] { + const auto *weight_data = weight.data_ptr(); + at::native::gpu_kernel(iter, + [weight_data] GPU_LAMBDA (scalar_t input_val) { + return (input_val > 0) ? input_val : *weight_data * input_val; + }); + }); +} + +template +__global__ void prelu_cuda_kernel_multi_weights( + scalar_t* result_data, + const scalar_t* input_data, + const scalar_t* weight_data, + int64_t input_stride0, + int64_t input_stride1, + int64_t input_numel) { + + int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; + if (linearId >= input_numel) return; + + // multiply values at each channel with weight[channel_index] + int64_t channel = (linearId % input_stride0) / input_stride1; + scalar_t input_data_val = input_data[linearId]; + result_data[linearId] = (input_data_val > 0) ? input_data_val : weight_data[channel] * input_data_val; +} + +void launch_prelu_cuda_kernel_multi_weights( + const TensorBase &result, const TensorBase &input, const TensorBase &weight) { + int64_t input_ndim = input.dim(); + TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + int64_t input_stride0 = 1, input_stride1 = 1; + + if (input_ndim > 1) { + channel_size = input.size(1); // channel is the 2nd dim of input + auto strides = input.strides(); + input_stride0 = strides[0]; + input_stride1 = strides[1]; + } + const int64_t weight_num = weight.numel(); + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + // config to run cuda kernel + int64_t input_numel = input.numel(); + const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); + dim3 grid; + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu: input too large or too many dimensions"); + + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_cuda", [&] { + prelu_cuda_kernel_multi_weights + <<>>( + result.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + input_stride0, + input_stride1, + input_numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +// ----------------------------------- +// prelu backward +// ----------------------------------- +void launch_prelu_cuda_backward_kernel_share_weights( + TensorIteratorBase &iter, const TensorBase &weight) { + // N.B. `std::tuple` does not support `::operator=` on device code. + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, iter.input_dtype(), "prelu_backward_cuda", [&] { + const auto *weight_data = weight.data_ptr(); + gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t input, scalar_t grad_out) -> thrust::tuple { + scalar_t input_grad = input > 0 ? grad_out : (*weight_data) * grad_out; + scalar_t weight_grad_collector = input > 0 ? scalar_t(0) : input * grad_out; + return {input_grad, weight_grad_collector}; + }); + }); +} + +template +__global__ void prelu_cuda_backward_kernel_multi_weights( + const scalar_t* input_data, + const scalar_t* weight_data, + const scalar_t* grad_out_data, + scalar_t* input_grad_data, + scalar_t* weight_grad_collector, + int64_t input_stride0, + int64_t input_stride1, + int64_t input_numel) { + + int64_t linearId = blockIdx.x * blockDim.x + threadIdx.x; + if (linearId >= input_numel) return; + int64_t channel = (linearId % input_stride0) / input_stride1; + scalar_t input_data_val = input_data[linearId]; + scalar_t grad_out_data_val = grad_out_data[linearId]; + input_grad_data[linearId] = (input_data_val > 0) ? grad_out_data_val : weight_data[channel] * grad_out_data_val; + weight_grad_collector[linearId] = (input_data_val > 0) ? scalar_t(0) : input_data_val * grad_out_data_val; +} + +void launch_prelu_cuda_backward_kernel_multi_weights( + const TensorBase &input, const TensorBase &weight, const TensorBase &grad_out, + const TensorBase &input_grad, const TensorBase &weight_grad_collector) { + int64_t input_ndim = input.dim(); + TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor."); + + int64_t channel_size = 1; // channel_size default to 1 + int64_t input_stride0 = 1, input_stride1 = 1; + + if (input_ndim > 1) { + channel_size = input.size(1); // channel is the 2nd dim of input + auto strides = input.strides(); + input_stride0 = strides[0]; + input_stride1 = strides[1]; + } + const int64_t weight_num = weight.numel(); + TORCH_CHECK(channel_size == weight_num, + "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num, + " and channel size = ", channel_size, "."); + + // config to run cuda kernel + int64_t input_numel = input.numel(); + const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), input_numel)); + dim3 grid; + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + TORCH_CHECK(cuda::getApplyGrid(input_numel, grid, curDevice), "prelu_backward_cuda: input too large or too many dimensions"); + + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "prelu_backward_cuda", [&] { + prelu_cuda_backward_kernel_multi_weights + <<>>( + input.data_ptr(), + weight.data_ptr(), + grad_out.data_ptr(), + input_grad.data_ptr(), + weight_grad_collector.data_ptr(), + input_stride0, + input_stride1, + input_numel); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu new file mode 100644 index 0000000000000..701b901e4f773 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -0,0 +1,61 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void silu_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t x_acc = static_cast(x); + return x_acc / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); + }); + }); +} + +void silu_backward_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "silu_backward_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t { + using opmath_t = at::opmath_type; + const opmath_t dy_acc = static_cast(dy); + const opmath_t x_acc = static_cast(x); + const opmath_t s_acc = + opmath_t(1) / (opmath_t(1) + c10::cuda::compat::exp(-x_acc)); + return dy_acc * s_acc * (opmath_t(1) + x_acc * (opmath_t(1) - s_acc)); + }); + }); +} +} // namespace + +REGISTER_DISPATCH(silu_stub, &silu_kernel); +REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu new file mode 100644 index 0000000000000..86c04221b24f0 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -0,0 +1,76 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void softplus_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel(iter, [beta, threshold] GPU_LAMBDA(scalar_t a) -> scalar_t { + opmath_t aop = static_cast(a); + return (aop * beta) > threshold + ? aop + : (::log1p(std::exp(aop * beta))) / beta; + }); + }); +} + +void softplus_backward_kernel( + TensorIteratorBase& iter, + const Scalar& beta_, + const Scalar& threshold_) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softplus_backward_cuda", + [&]() { + using opmath_t = at::opmath_type; + auto beta = beta_.to(); + auto threshold = threshold_.to(); + gpu_kernel( + iter, + [beta, threshold] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + opmath_t aop = static_cast(a); + opmath_t bop = static_cast(b); + opmath_t z = std::exp(bop * beta); + return (bop * beta) > threshold ? aop + : aop * z / (z + opmath_t(1.)); + }); + }); +} + +} // namespace + +REGISTER_DISPATCH(softplus_stub, &softplus_kernel); +REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu new file mode 100644 index 0000000000000..e21e3b94fac48 --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -0,0 +1,60 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +void softshrink_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "softshrink_cuda", + [&]() { + auto lambd = value.to(); + gpu_kernel(iter, [lambd] GPU_LAMBDA(scalar_t a) -> scalar_t { + return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0)); + }); + }); +} + +void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& value) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "shrink_backward_cuda", + [&]() { + auto lambd = value.to(); + gpu_kernel( + iter, + [lambd] GPU_LAMBDA( + scalar_t grad_val, scalar_t self_val) -> scalar_t { + return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) + : grad_val; + }); + }); +} +} // namespace + +REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel); +REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu new file mode 100644 index 0000000000000..86d8bbd528c8f --- /dev/null +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -0,0 +1,54 @@ +#define TORCH_ASSERT_NO_OPERATORS +#define _USE_MATH_DEFINES + +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace { + +template +void threshold_kernel_impl( + TensorIteratorBase& iter, + scalar_t threshold, + scalar_t value) { + gpu_kernel_with_scalars( + iter, [=] GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t { + return x <= threshold ? value : other; + }); +} + +static void threshold_kernel_cuda( + TensorIteratorBase& iter, + const Scalar& threshold, + const Scalar& value) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.dtype(), + "threshold_cuda", + [&] { + threshold_kernel_impl( + iter, threshold.to(), value.to()); + }); +} + +} // namespace + +REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu b/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu new file mode 100644 index 0000000000000..13e9757b5f39d --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu @@ -0,0 +1,112 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +void div_floor_kernel_cuda(TensorIteratorBase& iter) { + // See NOTE: [Floor Division in Python] + const auto dtype = iter.common_dtype(); + if (dtype == kByte) { + // In the special case of unsigned integer division, floor division is + // equivalent to truncation division (since the signs of the divisor and + // dividend are always the same) + return div_trunc_kernel_cuda(iter); + } else if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cuda", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + if (c10::signs_differ(a, b)) { + // Subtracts one from the results of truncation division if the + // divisor and dividend have different sign(bit)s and the + // remainder of the division is nonzero + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + + return a / b; + }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { + using accscalar_t = at::acc_type; + auto b = iter.scalar_value(2); + if (C10_UNLIKELY(b == 0)) { + return div_true_kernel_cuda(iter); + } + + auto inv_b = accscalar_t(1.0) / b; + iter.remove_operand(2); + gpu_kernel(iter, [b, inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + auto mod = std::fmod(a, b); + auto div = (a - mod) * inv_b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = c10::cuda::compat::copysign(scalar_t(0), a * inv_b); + } + return floordiv; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + if (C10_UNLIKELY(b == 0)) { + return a / b; + } + + auto mod = std::fmod(a, b); + auto div = (a - mod) / b; + if ((mod != 0) && (b < 0) != (mod < 0)) { + div -= scalar_t(1); + } + + scalar_t floordiv; + if (div != 0) { + floordiv = std::floor(div); + if (div - floordiv > scalar_t(0.5)) { + floordiv += scalar_t(1.0); + } + } else { + floordiv = c10::cuda::compat::copysign(scalar_t(0), a / b); + } + return floordiv; + }); + }); + } +} +} // namespace binary_internal + +REGISTER_DISPATCH(div_floor_stub, &binary_internal::div_floor_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu new file mode 100644 index 0000000000000..642318d2239fb --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu @@ -0,0 +1,63 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +const char div_name[] = "div_kernel"; +void div_true_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (iter.common_dtype() == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto div_string = jiterator_stringify( + template T div_kernel(T a, T b) { return a / b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, div_string); +#else + using opmath_t = at::opmath_type; + opmath_gpu_kernel_with_scalars(iter, DivFunctor()); +#endif + return; + } + if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { + using opmath_t = at::opmath_type; + auto inv_b = opmath_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel( + iter, + BUnaryFunctor>( + MulFunctor(), inv_b)); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { + DivFunctor f; + gpu_kernel_with_scalars(iter, f); + }); + } +} +} // namespace binary_internal + +REGISTER_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu b/aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu new file mode 100644 index 0000000000000..01a04b40cbc1a --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryDivTruncKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +void div_trunc_kernel_cuda(TensorIteratorBase& iter) { + auto dtype = iter.common_dtype(); + if (isIntegralType(dtype, /*includeBool*/ false)) { + AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() { + gpu_kernel_with_scalars( + iter, + [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a / b; }); + }); + } else if (iter.is_cpu_scalar(2)) { + // optimization for floating-point types: if the second operand is a CPU + // scalar, compute a * reciprocal(b). Note that this may lose one bit of + // precision compared to computing the division. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { + using accscalar_t = at::acc_type; + auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); + iter.remove_operand(2); + gpu_kernel(iter, [inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t { + return std::trunc(a * inv_b); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { + gpu_kernel_with_scalars( + iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { + return std::trunc(a / b); + }); + }); + } +} +} // namespace binary_internal + +REGISTER_DISPATCH(div_trunc_stub, &binary_internal::div_trunc_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryInternal.h b/aten/src/ATen/native/cuda/BinaryInternal.h new file mode 100644 index 0000000000000..e098d32b114d6 --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryInternal.h @@ -0,0 +1,48 @@ +// DON'T include this except from Binary*.cu files. It should not leak into +// headers. +#pragma once +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { +namespace native { +namespace binary_internal { + +template +struct DivFunctor { + __device__ scalar_t operator()(scalar_t a, scalar_t b) const { + return a / b; + } +}; + +template +struct MulFunctor { + __device__ T operator()(T a, T b) const { + return a * b; + } +}; + +// Workaround for the error: '*' in boolean context, suggest '&&' instead +// [-Werror=int-in-bool-context] +template <> +struct MulFunctor { + __device__ bool operator()(bool a, bool b) const { + return a && b; + } +}; +void div_true_kernel_cuda(TensorIteratorBase& iter); +void div_trunc_kernel_cuda(TensorIteratorBase& iter); +} // namespace binary_internal +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu b/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu deleted file mode 100644 index f3998bf9c2cd9..0000000000000 --- a/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu +++ /dev/null @@ -1,222 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -// NOTE: CUDA on Windows requires that the enclosing function -// of a __device__ lambda not have internal linkage. - -namespace at { namespace native { - -template -struct DivFunctor { - __device__ scalar_t operator() (scalar_t a, scalar_t b) const { - return a / b; - } -}; - -template -struct MulFunctor { - __device__ T operator() (T a, T b) const { - return a * b; - } -}; - -// Workaround for the error: '*' in boolean context, suggest '&&' instead [-Werror=int-in-bool-context] -template<> -struct MulFunctor { - __device__ bool operator() (bool a, bool b) const { - return a && b; - } -}; - -const char div_name[] = "div_kernel"; -void div_true_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (iter.common_dtype() == kComplexHalf) { - using scalar_t = c10::complex; - #if AT_USE_JITERATOR() - static const auto div_string = jiterator_stringify( - template - T div_kernel(T a, T b) { - return a / b; - } - ); - opmath_jitted_gpu_kernel_with_scalars(iter, div_string); - #else - using opmath_t = at::opmath_type; - opmath_gpu_kernel_with_scalars(iter, DivFunctor()); - #endif - return; - } - if (iter.is_cpu_scalar(2)) { - // optimization for floating-point types: if the second operand is a CPU - // scalar, compute a * reciprocal(b). Note that this may lose one bit of - // precision compared to computing the division. - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { - using opmath_t = at::opmath_type; - auto inv_b = opmath_t(1.0) / iter.scalar_value(2); - iter.remove_operand(2); - gpu_kernel(iter, BUnaryFunctor>( - MulFunctor(), inv_b)); - }); - } else { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() { - DivFunctor f; - gpu_kernel_with_scalars(iter, f); - }); - } -} - -void div_trunc_kernel_cuda(TensorIteratorBase& iter) { - auto dtype = iter.common_dtype(); - if (isIntegralType(dtype, /*includeBool*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - return a / b; - }); - }); - } else if (iter.is_cpu_scalar(2)) { - // optimization for floating-point types: if the second operand is a CPU - // scalar, compute a * reciprocal(b). Note that this may lose one bit of - // precision compared to computing the division. - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { - using accscalar_t = at::acc_type; - auto inv_b = accscalar_t(1.0) / iter.scalar_value(2); - iter.remove_operand(2); - gpu_kernel(iter, [inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t { - return std::trunc(a * inv_b); - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - return std::trunc(a / b); - }); - }); - } -} - -void div_floor_kernel_cuda(TensorIteratorBase& iter) { - // See NOTE: [Floor Division in Python] - const auto dtype = iter.common_dtype(); - if (dtype == kByte) { - // In the special case of unsigned integer division, floor division is - // equivalent to truncation division (since the signs of the divisor and - // dividend are always the same) - return div_trunc_kernel_cuda(iter); - } else if (isIntegralType(dtype, /*includeBool*/ false)) { - AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_floor_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - if (c10::signs_differ(a, b)) { - // Subtracts one from the results of truncation division if the - // divisor and dividend have different sign(bit)s and the remainder of - // the division is nonzero - const auto quot = a / b; - const auto rem = a % b; - return rem ? quot - 1 : quot; - } - - return a / b; - }); - }); - } else if (iter.is_cpu_scalar(2)) { - // optimization for floating-point types: if the second operand is a CPU - // scalar, compute a * reciprocal(b). Note that this may lose one bit of - // precision compared to computing the division. - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { - using accscalar_t = at::acc_type; - auto b = iter.scalar_value(2); - if (C10_UNLIKELY(b == 0)) { - return div_true_kernel_cuda(iter); - } - - auto inv_b = accscalar_t(1.0) / b; - iter.remove_operand(2); - gpu_kernel(iter, [b, inv_b] GPU_LAMBDA (scalar_t a) -> scalar_t { - auto mod = std::fmod(a, b); - auto div = (a - mod) * inv_b; - if ((mod != 0) && (b < 0) != (mod < 0)) { - div -= scalar_t(1); - } - - scalar_t floordiv; - if (div != 0) { - floordiv = std::floor(div); - if (div - floordiv > scalar_t(0.5)) { - floordiv += scalar_t(1.0); - } - } else { - floordiv = c10::cuda::compat::copysign(scalar_t(0), a * inv_b); - } - return floordiv; - }); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, dtype, "div_floor_cuda", [&]() { - gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { - if (C10_UNLIKELY(b == 0)) { - return a / b; - } - - auto mod = std::fmod(a, b); - auto div = (a - mod) / b; - if ((mod != 0) && (b < 0) != (mod < 0)) { - div -= scalar_t(1); - } - - scalar_t floordiv; - if (div != 0) { - floordiv = std::floor(div); - if (div - floordiv > scalar_t(0.5)) { - floordiv += scalar_t(1.0); - } - } else { - floordiv = c10::cuda::compat::copysign(scalar_t(0), a / b); - } - return floordiv; - }); - }); - } -} - -const char mul_name[] = "mul_kernel"; -void mul_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (common_dtype == kComplexHalf) { - using scalar_t = c10::complex; - #if AT_USE_JITERATOR() - static const auto mul_string = jiterator_stringify( - template - T mul_kernel(T a, T b) { - return a * b; - } - ); - opmath_jitted_gpu_kernel_with_scalars(iter, mul_string); - #else - using opmath_t = at::opmath_type; - opmath_symmetric_gpu_kernel_with_scalars(iter, MulFunctor()); - #endif - } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { - using opmath_t = at::opmath_type; - opmath_symmetric_gpu_kernel_with_scalars(iter, MulFunctor()); - }); - } -} - -REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda); -REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel_cuda); -REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel_cuda); -REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/BinaryMulKernel.cu b/aten/src/ATen/native/cuda/BinaryMulKernel.cu new file mode 100644 index 0000000000000..b0b4f4886ab85 --- /dev/null +++ b/aten/src/ATen/native/cuda/BinaryMulKernel.cu @@ -0,0 +1,50 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +// NOTE: CUDA on Windows requires that the enclosing function +// of a __device__ lambda not have internal linkage. + +namespace at { +namespace native { + +const char mul_name[] = "mul_kernel"; +void mul_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (common_dtype == kComplexHalf) { + using scalar_t = c10::complex; +#if AT_USE_JITERATOR() + static const auto mul_string = jiterator_stringify( + template T mul_kernel(T a, T b) { return a * b; }); + opmath_jitted_gpu_kernel_with_scalars( + iter, mul_string); +#else + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); +#endif + } else { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( + kHalf, kBFloat16, kBool, iter.common_dtype(), "mul_cuda", [&]() { + using opmath_t = at::opmath_type; + opmath_symmetric_gpu_kernel_with_scalars( + iter, binary_internal::MulFunctor()); + }); + } +} + +REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu new file mode 100644 index 0000000000000..292404cb36acb --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu @@ -0,0 +1,51 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + MinMaxOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), + at::numeric_limits::lower_bound())); +} + +void aminmax_allreduce_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { + _min_max_values_kernel_cuda_impl(iter); + }); +} + +void aminmax_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { + gpu_reduce_kernel( + iter, + MinMaxOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), + at::numeric_limits::lower_bound())); + }); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu new file mode 100644 index 0000000000000..fd8e071cd5c8d --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceArgMaxKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +void argmax_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + ArgMaxOps{}, + thrust::pair( + at::numeric_limits::lower_bound(), 0)); +}; + +void argmax_kernel_cuda(TensorIterator& iter) { + // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, + // we can convert float16 & bfloat16 to float and do all the operations in + // float. + if (iter.dtype(1) == kHalf) { + argmax_kernel_cuda_impl(iter); + } else if (iter.dtype(1) == kBFloat16) { + argmax_kernel_cuda_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() { + argmax_kernel_cuda_impl(iter); + }); + } +} + +REGISTER_DISPATCH(argmax_stub, &argmax_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu b/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu new file mode 100644 index 0000000000000..20eb736e49457 --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceArgMinKernel.cu @@ -0,0 +1,48 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +void argmin_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + ArgMinOps{}, + thrust::pair( + at::numeric_limits::upper_bound(), 0)); +}; + +void argmin_kernel_cuda(TensorIterator& iter) { + // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, + // we can convert float16 & bfloat16 to float and do all the operations in + // float. + if (iter.dtype(1) == kHalf) { + argmin_kernel_cuda_impl(iter); + } else if (iter.dtype(1) == kBFloat16) { + argmin_kernel_cuda_impl(iter); + } else { + AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() { + argmin_kernel_cuda_impl(iter); + }); + } +} + +REGISTER_DISPATCH(argmin_stub, &argmin_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu new file mode 100644 index 0000000000000..a5363838ee257 --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu @@ -0,0 +1,63 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { +namespace native { + +template +struct MaxNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (at::_isnan(a) || a > b) ? a : b; + } +}; + +template +void max_values_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, + func_wrapper(MaxNanFunctor()), + at::numeric_limits::lower_bound()); +} + +void max_values_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { + max_values_kernel_cuda_impl(iter); + }); +} + +void max_launch_kernel(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { + gpu_reduce_kernel( + iter, + MaxOps{}, + thrust::pair( + at::numeric_limits::lower_bound(), 0)); + }); +} + +void max_all_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { + max_values_kernel_cuda_impl(iter); + }); +} + +REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu deleted file mode 100644 index db9e9b5e60aff..0000000000000 --- a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu +++ /dev/null @@ -1,168 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - - -namespace at { namespace native { - -template -struct MaxNanFunctor { - __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { - return (at::_isnan(a) || a > b) ? a : b; - } -}; - -template -void max_values_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, func_wrapper (MaxNanFunctor()), - at::numeric_limits::lower_bound()); -} - -template -struct MinNanFunctor { - __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { - return (at::_isnan(a) || a < b) ? a : b; - } -}; - -template -void min_values_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, func_wrapper (MinNanFunctor()), - at::numeric_limits::upper_bound()); -} - -void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { - max_values_kernel_cuda_impl(iter); - }); -} - -void min_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { - min_values_kernel_cuda_impl(iter); - }); -} - -template -void argmax_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, - ArgMaxOps{}, - thrust::pair(at::numeric_limits::lower_bound(), 0)); -}; - -template -void argmin_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, - ArgMinOps{}, - thrust::pair(at::numeric_limits::upper_bound(), 0)); -}; - -void argmax_kernel_cuda(TensorIterator& iter) { - // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, - // we can convert float16 & bfloat16 to float and do all the operations in float. - if (iter.dtype(1) == kHalf) { - argmax_kernel_cuda_impl(iter); - } else if (iter.dtype(1) == kBFloat16) { - argmax_kernel_cuda_impl(iter); - } else { - AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmax_cuda", [&]() { - argmax_kernel_cuda_impl(iter); - }); - } -} - -void argmin_kernel_cuda(TensorIterator& iter) { - // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down, - // we can convert float16 & bfloat16 to float and do all the operations in float. - if (iter.dtype(1) == kHalf) { - argmin_kernel_cuda_impl(iter); - } else if (iter.dtype(1) == kBFloat16) { - argmin_kernel_cuda_impl(iter); - } else { - AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() { - argmin_kernel_cuda_impl(iter); - }); - } -} - -void min_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { - gpu_reduce_kernel( - iter, - MinOps{}, - thrust::pair(at::numeric_limits::upper_bound(), 0)); - }); -} - -void max_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { - gpu_reduce_kernel( - iter, - MaxOps{}, - thrust::pair(at::numeric_limits::lower_bound(), 0)); - }); -} - -void aminmax_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { - gpu_reduce_kernel( - iter, - MinMaxOps{}, - thrust::pair( - at::numeric_limits::upper_bound(), - at::numeric_limits::lower_bound() - ) - ); - }); -} - -void min_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { - min_values_kernel_cuda_impl(iter); - }); -} - -void max_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { - max_values_kernel_cuda_impl(iter); - }); -} - -template -void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { - gpu_reduce_kernel( - iter, MinMaxOps{}, thrust::pair( - at::numeric_limits::upper_bound(), - at::numeric_limits::lower_bound() - )); -} - -void aminmax_allreduce_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { - _min_max_values_kernel_cuda_impl(iter); - }); -} - -REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda); -REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda); -REGISTER_DISPATCH(argmax_stub, &argmax_kernel_cuda); -REGISTER_DISPATCH(argmin_stub, &argmin_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu new file mode 100644 index 0000000000000..54d0f8499e541 --- /dev/null +++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu @@ -0,0 +1,58 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + + +namespace at { namespace native { + +template +struct MinNanFunctor { + __device__ __forceinline__ acc_t operator()(acc_t a, acc_t b) const { + return (at::_isnan(a) || a < b) ? a : b; + } +}; + +template +void min_values_kernel_cuda_impl(TensorIterator& iter) { + gpu_reduce_kernel( + iter, func_wrapper (MinNanFunctor()), + at::numeric_limits::upper_bound()); +} + +void min_values_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { + min_values_kernel_cuda_impl(iter); + }); +} + +void min_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { + gpu_reduce_kernel( + iter, + MinOps{}, + thrust::pair(at::numeric_limits::upper_bound(), 0)); + }); +} + +void min_all_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { + min_values_kernel_cuda_impl(iter); + }); +} + +REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu index bdd42d6559223..e5e3274fd69a9 100644 --- a/aten/src/ATen/native/cuda/Sort.cu +++ b/aten/src/ATen/native/cuda/Sort.cu @@ -280,220 +280,4 @@ void sortKeyValueInplace( } } -namespace { - -struct offset_t { - int stride; - int begin; - __device__ int operator[](int i) { - return stride * (begin + i); - } -}; - -} - -namespace { - -// Segmented sort by full sort algorithm:. -// Say we are sorting a (2, 3) tensor. We have in flattened form: -// values 0.4 1.2 5.3 6.2 1.3 2.3 -// indices 0 1 2 0 1 2 -// segment_id 0 0 0 1 1 1 - -// First we sort by values, globally: -// values 6.2 5.3 2.3 1.2 1.3 0.4 -// indices 0 2 2 1 1 0 -// segment_id 1 0 1 0 1 0 - -// Then we stable sort by segment id: -// values 5.3 1.2 0.4 6.2 2.3 1.3 -// indices 2 1 0 0 2 1 -// segment_id 0 0 0 1 1 1 - -// This method can only work if the slice we are sorting (`dim`) is -// innermost, and both values and indices are contiguous. We do this -// by re-arranging the input into this form as needed, which will -// unfortunately allocate memory if the request is not in this form. -// Vectorized sort is slower than iterated sort if the number of -// slices is small (since we're sorting twice, instead of invoking a -// smaller sort `numSlices` times), but the cub sort -// implementation here is a catch-all, so we're not looking for -// efficiency, but instead correctness. - -template -__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) { - CUDA_KERNEL_LOOP(i, nsegments * nsort) { - int segment = i / nsort; - int j = i % nsort; - - int offset = segment * nsort; - const scalar_t *in_ = in + offset; - scalar_t *out_ = out + offset; - int64_t *index_ = index + offset; - const int2 *i_s_ptr_ = i_s_ptr + offset; - - int idx = i_s_ptr_[j].y; - index_[j] = idx; - out_[j] = in_[idx]; - } -} - - -C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) -__global__ void fill_index_and_segment_kernel( - int2 *data, int numel, at::cuda::detail::IntDivider nsort_divider) { - CUDA_KERNEL_LOOP(idx, numel) { - auto div_mod = nsort_divider.divmod(idx); - auto segment = static_cast(div_mod.div); - auto sort = static_cast(div_mod.mod); - data[idx] = int2{segment, sort}; - } -} - -C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) -__global__ void fill_reverse_indices_kernel( - int64_t *data, int numel, at::cuda::detail::IntDivider nsort_divider) { - CUDA_KERNEL_LOOP(idx, numel) { - data[idx] = nsort_divider.mod(idx); - } -} - -template -inline void segmented_sort_large_segments( - const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending, - const scalar_t * self_ptr, scalar_t * values_ptr, int64_t * indices_ptr - ) { - using namespace at::cuda::detail; - auto allocator = at::cuda::getCUDADeviceAllocator(); - auto stream = at::cuda::getCurrentCUDAStream(); - dim3 block = CUDA_NUM_THREADS; - dim3 grid = GET_BLOCKS(nsort); - c10::DeviceArray indices(*allocator, nsort); - at::cuda::detail::IntDivider nsort_divider(nsort); - fill_reverse_indices_kernel<<>>( - indices.get(), nsort, nsort_divider); - const int64_t *initial_indices = indices.get(); - - for (auto i: c10::irange(nsegments)){ - at::cuda::cub::radix_sort_pairs( - self_ptr, values_ptr, initial_indices, indices_ptr, - nsort, descending); - indices_ptr += nsort; - self_ptr += nsort; - values_ptr += nsort; - } -} - -template -inline void segmented_sort_pairs_by_full_sort( - const int64_t nsegments, const int64_t nsort, const int64_t n, const bool descending, - const scalar_t *const self_ptr, scalar_t *const values_ptr, int64_t *const indices_ptr -) { - int64_t segment_bits = std::max(1L, static_cast(std::ceil(std::log2(nsegments)))); - - const auto numel = nsort * nsegments; - auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); - auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2)); - auto i_s_ptr = static_cast(indices_and_segment.get()); - - using namespace at::cuda::detail; - dim3 block = CUDA_NUM_THREADS; - dim3 grid = GET_BLOCKS(numel); - auto stream = c10::cuda::getCurrentCUDAStream(); - at::cuda::detail::IntDivider nsort_divider(nsort); - fill_index_and_segment_kernel<<>>( - i_s_ptr, numel, nsort_divider); - - auto indices_and_segment2 = cuda_allocator->allocate(nsegments * nsort * sizeof(int2)); - auto i_s_ptr2 = static_cast(indices_and_segment2.get()); - - at::cuda::cub::radix_sort_pairs( - self_ptr, nullptr, i_s_ptr, i_s_ptr2, - n, descending); - - TORCH_INTERNAL_ASSERT(segment_bits <= 32); - - // sort on lower 32bits, i.e. segment index - at::cuda::cub::radix_sort_keys( - reinterpret_cast(i_s_ptr2), reinterpret_cast(i_s_ptr), - n, false, 0, segment_bits); - - sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>( - self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort); -} - -template -void segmented_sort_pairs( - int64_t nsegments, int64_t nsort, int64_t n, bool descending, - const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr) { - const auto numel = nsort * nsegments; - auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); - auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t)); - int64_t *reverse_indices_ptr = static_cast(reverse_indices.get()); - - using namespace at::cuda::detail; - dim3 block = CUDA_NUM_THREADS; - dim3 grid = GET_BLOCKS(numel); - auto stream = c10::cuda::getCurrentCUDAStream(); - at::cuda::detail::IntDivider nsort_divider(nsort); - fill_reverse_indices_kernel<<>>( - reverse_indices_ptr, numel, nsort_divider); - - at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr, - reverse_indices_ptr, indices_ptr, n, nsegments, - offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending); -} - -} // namespace - -void launch_stable_sort_kernel( - const TensorBase &self, int64_t dim, bool descending, - const TensorBase &values, const TensorBase &indices) { - const auto numel = self.numel(); - if (numel == 0) { - return; - } - - int64_t numel_or_intmax = std::min(numel, static_cast(std::numeric_limits::max())); - int64_t nsort = self.size(dim); - int64_t nbatch = (numel_or_intmax / nsort) * nsort; - TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); - int64_t *indices_ptr = indices.data_ptr(); - -#if (defined(USE_ROCM) && ROCM_VERSION < 40500) - constexpr bool is_rocm_bf16_sort_unsupported = true; -#else - constexpr bool is_rocm_bf16_sort_unsupported = false; -#endif - - AT_DISPATCH_ALL_TYPES_AND3(kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&]{ - c10::guts::if_constexpr::value)>([&](auto _){ - const scalar_t *self_ptr = self.data_ptr(); - scalar_t *values_ptr = values.data_ptr(); - int64_t remaining = _(numel); - while (remaining > 0) { - int64_t n = std::min(remaining, nbatch); - int64_t nsegments = n / nsort; - - if (nsegments == 1 || nsort >= 1000000) { //rough heuristics where even a single sort occupies GPU - segmented_sort_large_segments( - nsegments, nsort, n, descending, - self_ptr, values_ptr, indices_ptr); - } else if (nsegments < 128) { - segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending, - self_ptr, values_ptr, indices_ptr); - } else { - segmented_sort_pairs(nsegments, nsort, n, descending, - self_ptr, values_ptr, indices_ptr); - } - - remaining -= n; - self_ptr += n; - values_ptr += n; - indices_ptr += n; - } - }, [&](auto _){ TORCH_CHECK(_(false), "BFloat16 is not supported on ROCm < 4.5"); }); - }); -} - }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Sort.h b/aten/src/ATen/native/cuda/Sort.h index f81b1fdd42839..656b4ce2c2bba 100644 --- a/aten/src/ATen/native/cuda/Sort.h +++ b/aten/src/ATen/native/cuda/Sort.h @@ -1,16 +1,11 @@ #pragma once #include #include +#include namespace at { namespace native { -// Stable-sort self into values, and set indices to the -// inverse-permutation from values back to self. -// Output tensors must be pre-allocated and contiguous. -void launch_stable_sort_kernel(const TensorBase &self, int64_t dim, bool descending, - const TensorBase &values, const TensorBase &indices); - inline bool should_use_small_sort(const TensorBase &self, int64_t dim) { return self.size(dim) <= 4096; } diff --git a/aten/src/ATen/native/cuda/SortStable.cu b/aten/src/ATen/native/cuda/SortStable.cu new file mode 100644 index 0000000000000..cf6ffb778e57a --- /dev/null +++ b/aten/src/ATen/native/cuda/SortStable.cu @@ -0,0 +1,299 @@ + +#define TORCH_ASSERT_NO_OPERATORS +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { +namespace native { + +namespace { + +struct offset_t { + int stride; + int begin; + __device__ int operator[](int i) { + return stride * (begin + i); + } +}; +// Segmented sort by full sort algorithm:. +// Say we are sorting a (2, 3) tensor. We have in flattened form: +// values 0.4 1.2 5.3 6.2 1.3 2.3 +// indices 0 1 2 0 1 2 +// segment_id 0 0 0 1 1 1 + +// First we sort by values, globally: +// values 6.2 5.3 2.3 1.2 1.3 0.4 +// indices 0 2 2 1 1 0 +// segment_id 1 0 1 0 1 0 + +// Then we stable sort by segment id: +// values 5.3 1.2 0.4 6.2 2.3 1.3 +// indices 2 1 0 0 2 1 +// segment_id 0 0 0 1 1 1 + +// This method can only work if the slice we are sorting (`dim`) is +// innermost, and both values and indices are contiguous. We do this +// by re-arranging the input into this form as needed, which will +// unfortunately allocate memory if the request is not in this form. +// Vectorized sort is slower than iterated sort if the number of +// slices is small (since we're sorting twice, instead of invoking a +// smaller sort `numSlices` times), but the cub sort +// implementation here is a catch-all, so we're not looking for +// efficiency, but instead correctness. + +template +__global__ void sort_postprocess_kernel( + const scalar_t* in, + scalar_t* out, + int64_t* index, + const int2* i_s_ptr, + int nsegments, + int nsort) { + CUDA_KERNEL_LOOP(i, nsegments * nsort) { + int segment = i / nsort; + int j = i % nsort; + + int offset = segment * nsort; + const scalar_t* in_ = in + offset; + scalar_t* out_ = out + offset; + int64_t* index_ = index + offset; + const int2* i_s_ptr_ = i_s_ptr + offset; + + int idx = i_s_ptr_[j].y; + index_[j] = idx; + out_[j] = in_[idx]; + } +} + +C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) +__global__ void fill_index_and_segment_kernel( + int2* data, + int numel, + at::cuda::detail::IntDivider nsort_divider) { + CUDA_KERNEL_LOOP(idx, numel) { + auto div_mod = nsort_divider.divmod(idx); + auto segment = static_cast(div_mod.div); + auto sort = static_cast(div_mod.mod); + data[idx] = int2{segment, sort}; + } +} + +C10_LAUNCH_BOUNDS_1(at::cuda::detail::CUDA_NUM_THREADS) +__global__ void fill_reverse_indices_kernel( + int64_t* data, + int numel, + at::cuda::detail::IntDivider nsort_divider) { + CUDA_KERNEL_LOOP(idx, numel) { + data[idx] = nsort_divider.mod(idx); + } +} + +template +inline void segmented_sort_large_segments( + const int64_t nsegments, + const int64_t nsort, + const int64_t n, + const bool descending, + const scalar_t* self_ptr, + scalar_t* values_ptr, + int64_t* indices_ptr) { + using namespace at::cuda::detail; + auto allocator = at::cuda::getCUDADeviceAllocator(); + auto stream = at::cuda::getCurrentCUDAStream(); + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(nsort); + c10::DeviceArray indices(*allocator, nsort); + at::cuda::detail::IntDivider nsort_divider(nsort); + fill_reverse_indices_kernel<<>>( + indices.get(), nsort, nsort_divider); + const int64_t* initial_indices = indices.get(); + + for (auto i : c10::irange(nsegments)) { + at::cuda::cub::radix_sort_pairs( + self_ptr, values_ptr, initial_indices, indices_ptr, nsort, descending); + indices_ptr += nsort; + self_ptr += nsort; + values_ptr += nsort; + } +} + +template +inline void segmented_sort_pairs_by_full_sort( + const int64_t nsegments, + const int64_t nsort, + const int64_t n, + const bool descending, + const scalar_t* const self_ptr, + scalar_t* const values_ptr, + int64_t* const indices_ptr) { + int64_t segment_bits = std::max( + 1L, static_cast(std::ceil(std::log2(nsegments)))); + + const auto numel = nsort * nsegments; + auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); + auto indices_and_segment = cuda_allocator->allocate(numel * sizeof(int2)); + auto i_s_ptr = static_cast(indices_and_segment.get()); + + using namespace at::cuda::detail; + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(numel); + auto stream = c10::cuda::getCurrentCUDAStream(); + at::cuda::detail::IntDivider nsort_divider(nsort); + fill_index_and_segment_kernel<<>>( + i_s_ptr, numel, nsort_divider); + + auto indices_and_segment2 = + cuda_allocator->allocate(nsegments * nsort * sizeof(int2)); + auto i_s_ptr2 = static_cast(indices_and_segment2.get()); + + at::cuda::cub::radix_sort_pairs( + self_ptr, nullptr, i_s_ptr, i_s_ptr2, n, descending); + + TORCH_INTERNAL_ASSERT(segment_bits <= 32); + + // sort on lower 32bits, i.e. segment index + at::cuda::cub::radix_sort_keys( + reinterpret_cast(i_s_ptr2), + reinterpret_cast(i_s_ptr), + n, + false, + 0, + segment_bits); + + sort_postprocess_kernel<<< + (n + 511) / 512, + 512, + 0, + at::cuda::getCurrentCUDAStream()>>>( + self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort); +} + +template +void segmented_sort_pairs( + int64_t nsegments, + int64_t nsort, + int64_t n, + bool descending, + const scalar_t* self_ptr, + scalar_t* values_ptr, + int64_t* indices_ptr) { + const auto numel = nsort * nsegments; + auto cuda_allocator = at::cuda::getCUDADeviceAllocator(); + auto reverse_indices = cuda_allocator->allocate(numel * sizeof(int64_t)); + int64_t* reverse_indices_ptr = static_cast(reverse_indices.get()); + + using namespace at::cuda::detail; + dim3 block = CUDA_NUM_THREADS; + dim3 grid = GET_BLOCKS(numel); + auto stream = c10::cuda::getCurrentCUDAStream(); + at::cuda::detail::IntDivider nsort_divider(nsort); + fill_reverse_indices_kernel<<>>( + reverse_indices_ptr, numel, nsort_divider); + + at::cuda::cub::segmented_sort_pairs( + self_ptr, + values_ptr, + reverse_indices_ptr, + indices_ptr, + n, + nsegments, + offset_t{(int)nsort, 0}, + offset_t{(int)nsort, 1}, + descending); +} + +} // namespace + +void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices) { + const auto numel = self.numel(); + if (numel == 0) { + return; + } + + int64_t numel_or_intmax = + std::min(numel, static_cast(std::numeric_limits::max())); + int64_t nsort = self.size(dim); + int64_t nbatch = (numel_or_intmax / nsort) * nsort; + TORCH_CHECK(nbatch > 0, "Cannot sort dimension of length ", nsort); + int64_t* indices_ptr = indices.data_ptr(); + +#if (defined(USE_ROCM) && ROCM_VERSION < 40500) + constexpr bool is_rocm_bf16_sort_unsupported = true; +#else + constexpr bool is_rocm_bf16_sort_unsupported = false; +#endif + + AT_DISPATCH_ALL_TYPES_AND3( + kBool, kHalf, kBFloat16, self.scalar_type(), "sort", [&] { + c10::guts::if_constexpr::value)>( + [&](auto _) { + const scalar_t* self_ptr = self.data_ptr(); + scalar_t* values_ptr = values.data_ptr(); + int64_t remaining = _(numel); + while (remaining > 0) { + int64_t n = std::min(remaining, nbatch); + int64_t nsegments = n / nsort; + + if (nsegments == 1 || + nsort >= 1000000) { // rough heuristics where even a single + // sort occupies GPU + segmented_sort_large_segments( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } else if (nsegments < 128) { + segmented_sort_pairs_by_full_sort( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } else { + segmented_sort_pairs( + nsegments, + nsort, + n, + descending, + self_ptr, + values_ptr, + indices_ptr); + } + + remaining -= n; + self_ptr += n; + values_ptr += n; + indices_ptr += n; + } + }, + [&](auto _) { + TORCH_CHECK(_(false), "BFloat16 is not supported on ROCm < 4.5"); + }); + }); +} +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/SortStable.h b/aten/src/ATen/native/cuda/SortStable.h new file mode 100644 index 0000000000000..039c4307c522c --- /dev/null +++ b/aten/src/ATen/native/cuda/SortStable.h @@ -0,0 +1,19 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +// Stable-sort self into values, and set indices to the +// inverse-permutation from values back to self. +// Output tensors must be pre-allocated and contiguous. +void launch_stable_sort_kernel( + const TensorBase& self, + int64_t dim, + bool descending, + const TensorBase& values, + const TensorBase& indices); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu new file mode 100644 index 0000000000000..d27a74c19c960 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char acos_name[] = "acos"; +void acos_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto acos_string = jiterator_stringify( + template T acos(T a) { return std::acos(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "acos_name", [&]() { + jitted_gpu_kernel< + /*name=*/acos_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, acos_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "acos_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::acos(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "acos_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::acos(a); + }); + }); + } +} + +REGISTER_DISPATCH(acos_stub, &acos_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu new file mode 100644 index 0000000000000..f831e9e5b8710 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char acosh_name[] = "acosh"; +void acosh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if(at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto acosh_string = jiterator_stringify( + template + T acosh(T a) { + return std::acosh(a); + } + ); + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() { + jitted_gpu_kernel< + /*name=*/ acosh_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, acosh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::acosh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + common_dtype, "acosh_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::acosh(a); + }); + }); + } +} + +REGISTER_DISPATCH(acosh_stub, &acosh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu new file mode 100644 index 0000000000000..fdabb67717741 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu @@ -0,0 +1,52 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char asin_name[] = "asin"; +void asin_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto asin_string = jiterator_stringify( + template T asin(T a) { return std::asin(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asin_name", [&]() { + jitted_gpu_kernel< + /*name=*/asin_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, asin_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asin_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::asin(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, common_dtype, "asin_cuda", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::asin(a); + }); + }); + } +} + +REGISTER_DISPATCH(asin_stub, &asin_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu new file mode 100644 index 0000000000000..e1cf41b46db50 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char asinh_name[] = "asinh"; +void asinh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto asinh_string = jiterator_stringify( + template T asinh(T a) { return std::asinh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asinh_name", [&]() { + jitted_gpu_kernel< + /*name=*/asinh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, asinh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "asinh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::asinh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "asinh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::asinh(a); + }); + }); + } +} + +REGISTER_DISPATCH(asinh_stub, &asinh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu new file mode 100644 index 0000000000000..4bbbb95d384f9 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char atan_name[] = "atan"; +void atan_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto atan_string = jiterator_stringify( + template + T atan(T a) { + return std::atan(a); + } + ); + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { + jitted_gpu_kernel< + /*name=*/ atan_name, + /*return_dtype=*/ scalar_t, + /*common_dtype=*/ scalar_t, + /*arity=*/ 1>(iter, atan_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::atan(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, ScalarType::BFloat16, + common_dtype, "atan_cuda", + [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::atan(a); + }); + }); + } +} + +REGISTER_DISPATCH(atan_stub, &atan_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu new file mode 100644 index 0000000000000..461a0f042205d --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char atanh_name[] = "atanh"; +void atanh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto atanh_string = jiterator_stringify( + template T atanh(T a) { return std::atanh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "atanh_name", [&]() { + jitted_gpu_kernel< + /*name=*/atanh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, atanh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "atanh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::atanh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "atanh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::atanh(a); + }); + }); + } +} + +REGISTER_DISPATCH(atanh_stub, &atanh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu new file mode 100644 index 0000000000000..c246a1d332235 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char cos_name[] = "cos"; +void cos_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto cos_string = jiterator_stringify( + template T cos(T a) { return std::cos(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cos_name", [&]() { + jitted_gpu_kernel< + /*name=*/cos_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, cos_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cos_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::cos(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "cos_cuda", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::cos(a); }); + }); + } +} + +REGISTER_DISPATCH(cos_stub, &cos_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu new file mode 100644 index 0000000000000..1d1f479843286 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char cosh_name[] = "cosh"; +void cosh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto cosh_string = jiterator_stringify( + template T cosh(T a) { return std::cosh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cosh_name", [&]() { + jitted_gpu_kernel< + /*name=*/cosh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, cosh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "cosh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::cosh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "cosh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::cosh(a); + }); + }); + } +} + +REGISTER_DISPATCH(cosh_stub, &cosh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu b/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu deleted file mode 100644 index 2623e6c228376..0000000000000 --- a/aten/src/ATen/native/cuda/UnaryGeometricKernels.cu +++ /dev/null @@ -1,480 +0,0 @@ -#define TORCH_ASSERT_NO_OPERATORS -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at { namespace native { - -const char acos_name[] = "acos"; -void acos_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto acos_string = jiterator_stringify( - template - T acos(T a) { - return std::acos(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acos_name", [&]() { - jitted_gpu_kernel< - /*name=*/ acos_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, acos_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acos_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::acos(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "acos_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::acos(a); - }); - }); - } -} - -const char asin_name[] = "asin"; -void asin_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto asin_string = jiterator_stringify( - template - T asin(T a) { - return std::asin(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asin_name", [&]() { - jitted_gpu_kernel< - /*name=*/ asin_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, asin_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asin_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::asin(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, common_dtype, "asin_cuda", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::asin(a); - }); - }); - } -} - -const char atan_name[] = "atan"; -void atan_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto atan_string = jiterator_stringify( - template - T atan(T a) { - return std::atan(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { - jitted_gpu_kernel< - /*name=*/ atan_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, atan_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atan_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::atan(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "atan_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::atan(a); - }); - }); - } -} - -const char sin_name[] = "sin"; -void sin_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto sin_string = jiterator_stringify( - template - T sin(T a) { - return std::sin(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sin_name", [&]() { - jitted_gpu_kernel< - /*name=*/ sin_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, sin_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sin_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::sin(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "sin_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::sin(a); - }); - }); - } -} - -const char cos_name[] = "cos"; -void cos_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto cos_string = jiterator_stringify( - template - T cos(T a) { - return std::cos(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cos_name", [&]() { - jitted_gpu_kernel< - /*name=*/ cos_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, cos_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cos_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::cos(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "cos_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::cos(a); - }); - }); - } -} - -const char sinh_name[] = "sinh"; -void sinh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto sinh_string = jiterator_stringify( - template - T sinh(T a) { - return std::sinh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sinh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ sinh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, sinh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sinh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::sinh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "sinh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::sinh(a); - }); - }); - } -} - -const char cosh_name[] = "cosh"; -void cosh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto cosh_string = jiterator_stringify( - template - T cosh(T a) { - return std::cosh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cosh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ cosh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, cosh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "cosh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::cosh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "cosh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::cosh(a); - }); - }); - } -} - -const char tanh_name[] = "tanh"; -void tanh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto tanh_string = jiterator_stringify( - template - T tanh(T a) { - return std::tanh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tanh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ tanh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, tanh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tanh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::tanh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "tanh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::tanh(a); - }); - }); - } -} - -const char acosh_name[] = "acosh"; -void acosh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto acosh_string = jiterator_stringify( - template - T acosh(T a) { - return std::acosh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ acosh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, acosh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::acosh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "acosh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::acosh(a); - }); - }); - } -} - -const char asinh_name[] = "asinh"; -void asinh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto asinh_string = jiterator_stringify( - template - T asinh(T a) { - return std::asinh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asinh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ asinh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, asinh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "asinh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::asinh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "asinh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::asinh(a); - }); - }); - } -} - -const char atanh_name[] = "atanh"; -void atanh_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if(at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto atanh_string = jiterator_stringify( - template - T atanh(T a) { - return std::atanh(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atanh_name", [&]() { - jitted_gpu_kernel< - /*name=*/ atanh_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, atanh_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "atanh_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::atanh(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "atanh_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::atanh(a); - }); - }); - } -} - -const char tan_name[] = "tan"; -void tan_kernel_cuda(TensorIteratorBase& iter) { - auto common_dtype = iter.common_dtype(); - if (at::isComplexType(common_dtype)) { -#if AT_USE_JITERATOR - static const auto tan_string = jiterator_stringify( - template - T tan(T a) { - return std::tan(a); - } - ); - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tan_name", [&]() { - jitted_gpu_kernel< - /*name=*/ tan_name, - /*return_dtype=*/ scalar_t, - /*common_dtype=*/ scalar_t, - /*arity=*/ 1>(iter, tan_string); - }); -#else - AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "tan_name", [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - using opmath_t = at::opmath_type; - return ::tan(static_cast(a)); - }); - }); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2( - ScalarType::Half, ScalarType::BFloat16, - common_dtype, "tan_cuda", - [&]() { - gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return ::tan(a); - }); - }); - } -} - -REGISTER_DISPATCH(acos_stub, &acos_kernel_cuda); -REGISTER_DISPATCH(acosh_stub, &acosh_kernel_cuda); -REGISTER_DISPATCH(asinh_stub, &asinh_kernel_cuda); -REGISTER_DISPATCH(atanh_stub, &atanh_kernel_cuda); -REGISTER_DISPATCH(asin_stub, &asin_kernel_cuda); -REGISTER_DISPATCH(atan_stub, &atan_kernel_cuda); -REGISTER_DISPATCH(sin_stub, &sin_kernel_cuda); -REGISTER_DISPATCH(cos_stub, &cos_kernel_cuda); -REGISTER_DISPATCH(sinh_stub, &sinh_kernel_cuda); -REGISTER_DISPATCH(cosh_stub, &cosh_kernel_cuda); -REGISTER_DISPATCH(tanh_stub, &tanh_kernel_cuda); -REGISTER_DISPATCH(tan_stub, &tan_kernel_cuda); - -}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu new file mode 100644 index 0000000000000..833ecdccc18c2 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char sin_name[] = "sin"; +void sin_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto sin_string = jiterator_stringify( + template T sin(T a) { return std::sin(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sin_name", [&]() { + jitted_gpu_kernel< + /*name=*/sin_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, sin_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sin_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::sin(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "sin_cuda", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::sin(a); }); + }); + } +} + +REGISTER_DISPATCH(sin_stub, &sin_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu new file mode 100644 index 0000000000000..fb806aa84b66e --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char sinh_name[] = "sinh"; +void sinh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto sinh_string = jiterator_stringify( + template T sinh(T a) { return std::sinh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sinh_name", [&]() { + jitted_gpu_kernel< + /*name=*/sinh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, sinh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "sinh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::sinh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "sinh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::sinh(a); + }); + }); + } +} + +REGISTER_DISPATCH(sinh_stub, &sinh_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu new file mode 100644 index 0000000000000..a57499b337237 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu @@ -0,0 +1,55 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char tan_name[] = "tan"; +void tan_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto tan_string = jiterator_stringify( + template T tan(T a) { return std::tan(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tan_name", [&]() { + jitted_gpu_kernel< + /*name=*/tan_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, tan_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tan_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::tan(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "tan_cuda", + [&]() { + gpu_kernel( + iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::tan(a); }); + }); + } +} + +REGISTER_DISPATCH(tan_stub, &tan_kernel_cuda); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu new file mode 100644 index 0000000000000..ffaf36a028f61 --- /dev/null +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu @@ -0,0 +1,56 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { + +const char tanh_name[] = "tanh"; +void tanh_kernel_cuda(TensorIteratorBase& iter) { + auto common_dtype = iter.common_dtype(); + if (at::isComplexType(common_dtype)) { +#if AT_USE_JITERATOR + static const auto tanh_string = jiterator_stringify( + template T tanh(T a) { return std::tanh(a); }); + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tanh_name", [&]() { + jitted_gpu_kernel< + /*name=*/tanh_name, + /*return_dtype=*/scalar_t, + /*common_dtype=*/scalar_t, + /*arity=*/1>(iter, tanh_string); + }); +#else + AT_DISPATCH_COMPLEX_TYPES_AND( + kComplexHalf, common_dtype, "tanh_name", [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + using opmath_t = at::opmath_type; + return ::tanh(static_cast(a)); + }); + }); +#endif + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + common_dtype, + "tanh_cuda", + [&]() { + gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { + return ::tanh(a); + }); + }); + } +} + +REGISTER_DISPATCH(tanh_stub, &tanh_kernel_cuda); + +} // namespace native +} // namespace at