Skip to content

Commit

Permalink
[build] Split .cu to improve compile times (pytorch#81193)
Browse files Browse the repository at this point in the history
The goal is to speed up CUDA builds. I was looking at build times and found that we have large CUDA compilation units that take forever to compile and make parallelism less effective. This PR splits them up into different `.cu` files so we can parallelize compilation better. We've done this sort of thing in the past with some success.

With a cold build, timing before: 5m42.019s, timing after: 4m30.275s. That's a speedup of 18.1% for me.

Behaviorally this should be a no-op, I'm just moving code around. There is still more we can do here but I did most of the ones that are copypasta. The full list of remaining chonky compilation units is [here](https://gist.github.com/suo/0dc217733f40f59898a8cc4f60529d60).

## Details
Here's a screenshot from a ninja trace, with the following command:
```
MAX_JOBS=64 CCACHE_DISABLE=1 TORCH_CUDA_ARCH_LIST=Ampere BUILD_CAFFE2_OPS=0 USE_FBGEMM=0 USE_DISTRIBUTED=0 USE_MKLDNN=0 BUILD_TEST=0 USE_GOLD_LINKER=1 USE_OPENMP=1 USE_NCCL=0 DEBUG=0 python setup.py develop
```
<img width="1475" alt="image" src="https://user-images.githubusercontent.com/1617424/178170276-ee0e5eb0-4c16-4b86-b4af-2a9e615b7f5f.png">

([source trace](https://gist.github.com/suo/5f5458f2630f9ab6dcbea6989e892195), which you can visualize in [perfetto](https://ui.perfetto.dev/))

After this PR, we get somewhat better utilization (although there is plenty still left to do):
<img width="1466" alt="image" src="https://user-images.githubusercontent.com/1617424/178178944-63ca9ff0-9cd3-4008-9a6d-d8623b5148c5.png">

([source trace](https://gist.github.com/suo/5607335bcd4bd412d42b0c9334259184))
Pull Request resolved: pytorch#81193
Approved by: https://github.com/cpuhrsch, https://github.com/malfet
  • Loading branch information
suo authored and pytorchmergebot committed Jul 12, 2022
1 parent 282de55 commit fb93c39
Show file tree
Hide file tree
Showing 48 changed files with 2,952 additions and 1,914 deletions.
108 changes: 108 additions & 0 deletions aten/src/ATen/cuda/cub-RadixSortKeys.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/cub.cuh>

namespace at {
namespace cuda {
namespace cub {

// Device-wide radix sort of `n` keys on the current CUDA stream.
// Ascending order unless `descending` is set; `begin_bit`/`end_bit` restrict
// the comparison to a bit subrange of the key (useful for partial-key sorts).
// All pointers must be device pointers.
template <typename key_t>
void radix_sort_keys(
    const key_t* keys_in,
    key_t* keys_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  // CUB's DeviceRadixSort API takes a (signed) int item count.
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
  // Map c10 scalar types (e.g. c10::Half) onto the equivalent CUDA types
  // that CUB knows how to sort.
  using key_t_ = typename detail::cuda_type<key_t>::type;

  const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
        keys_in_,
        keys_out_,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
        keys_in_,
        keys_out_,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  }
}

// Copies `input` to `output`, dropping consecutive duplicate elements
// (CUB DeviceSelect::Unique). The number of elements kept is written to
// `*num_selected_out`. Runs on the current CUDA stream.
template <typename scalar_t>
void unique(
    const scalar_t* input,
    scalar_t* output,
    int64_t* num_selected_out,
    int64_t num_items) {
  // CUB's selection API takes a (signed) int item count.
  TORCH_CHECK(
      num_items <= std::numeric_limits<int>::max(),
      "cub unique does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
      input,
      output,
      num_selected_out,
      num_items,
      at::cuda::getCurrentCUDAStream());
}

// Run-length encodes `input` (CUB DeviceRunLengthEncode::Encode): unique runs
// go to `output`, each run's length to `counts_out`, and the number of runs
// to `*length_out`. Runs on the current CUDA stream.
template <typename scalar_t>
void run_length_encode(
    const scalar_t* input,
    scalar_t* output,
    int64_t* counts_out,
    int64_t* length_out,
    int64_t num_items) {
  TORCH_CHECK(
      num_items <= std::numeric_limits<int>::max(),
      "cub run_length_encode does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
      input,
      output,
      counts_out,
      length_out,
      num_items,
      at::cuda::getCurrentCUDAStream());
}

// Explicitly instantiates radix_sort_keys/unique/run_length_encode for one
// scalar type. (Macro name fixed from the historical "INSTATIATE" spelling;
// the last run_length_encode parameter is named `num_items` to match the
// template declaration above. Parameter names in explicit instantiations are
// cosmetic, so neither change affects the emitted symbols.)
#define AT_INSTANTIATE_CUB_TEMPLATES(scalar_t, ScalarType) \
  template void radix_sort_keys(                           \
      const scalar_t* keys_in,                             \
      scalar_t* keys_out,                                  \
      int64_t n,                                           \
      bool descending,                                     \
      int64_t begin_bit,                                   \
      int64_t end_bit);                                    \
  template void unique(                                    \
      const scalar_t* input,                               \
      scalar_t* output,                                    \
      int64_t* num_selected_out,                           \
      int64_t num_items);                                  \
  template void run_length_encode(                         \
      const scalar_t* input,                               \
      scalar_t* output,                                    \
      int64_t* counts_out,                                 \
      int64_t* length_out,                                 \
      int64_t num_items);

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_CUB_TEMPLATES)

// Keep the helper macro local to this translation unit.
#undef AT_INSTANTIATE_CUB_TEMPLATES

} // namespace cub
} // namespace cuda
} // namespace at
93 changes: 93 additions & 0 deletions aten/src/ATen/cuda/cub-RadixSortPairs.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/cub.cuh>

namespace at {
namespace cuda {
namespace cub {
namespace detail {

// Type-erased implementation behind radix_sort_pairs: values are treated as
// opaque `value_size`-byte blobs, so one instantiation serves every value
// type of a given size. Sorts `n` (key, value) pairs by key on the current
// CUDA stream. If `keys_out` is nullptr, scratch space for the sorted keys
// is allocated from the CUDA caching allocator and freed on return (the
// caller only wanted the reordered values).
template <typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t* keys_in,
    key_t* keys_out,
    const OpaqueType<value_size>* values_in,
    OpaqueType<value_size>* values_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  // CUB's DeviceRadixSort API takes a (signed) int item count.
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
  // Map c10 scalar types (e.g. c10::Half) onto the equivalent CUDA types
  // that CUB knows how to sort.
  using key_t_ = typename detail::cuda_type<key_t>::type;

  auto allocator = c10::cuda::CUDACachingAllocator::get();
  // Owns the scratch key buffer (if any); released when this goes out of
  // scope at the end of the function.
  c10::DataPtr keys_out_owner;

  if (keys_out == nullptr) {
    keys_out_owner = allocator->allocate(n * sizeof(key_t));
    keys_out = reinterpret_cast<key_t*>(keys_out_owner.get());
  }

  const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
        keys_in_,
        keys_out_,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
        keys_in_,
        keys_out_,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  }
}

// Explicit instantiation for one (key type, value size) combination.
#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \
  template void radix_sort_pairs_impl(               \
      const key_t* keys_in,                          \
      key_t* keys_out,                               \
      const OpaqueType<value_size>* values_in,       \
      OpaqueType<value_size>* values_out,            \
      int64_t n,                                     \
      bool descending,                               \
      int64_t begin_bit,                             \
      int64_t end_bit);

AT_INSTANTIATE_SORT_PAIRS(int32_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 4)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)

// 8-byte values for every standard scalar key type (plus Bool and Half);
// the second argument is required by AT_FORALL_* but unused here.
#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \
  AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)

// BFloat16 Radix sort is supported from ROCm 4.5 onwards
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
#endif

// Keep the helper macros local to this translation unit.
#undef AT_INSTANTIATE_SORT_PAIRS_8
#undef AT_INSTANTIATE_SORT_PAIRS

} // namespace detail

} // namespace cub
} // namespace cuda
} // namespace at
112 changes: 0 additions & 112 deletions aten/src/ATen/cuda/cub.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,118 +5,6 @@
namespace at {
namespace cuda {
namespace cub {
namespace detail {

// Type-erased radix sort of (key, value) pairs: values are handled as opaque
// `value_size`-byte blobs so one instantiation covers every value type of a
// given size. Sorts `n` pairs by key on the current CUDA stream. If
// `keys_out` is nullptr, scratch space for the sorted keys is allocated from
// the CUDA caching allocator and freed on return.
template<typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t *keys_in, key_t *keys_out,
    const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) {
  // CUB's DeviceRadixSort API takes a (signed) int item count.
  TORCH_CHECK(n <= std::numeric_limits<int>::max(),
    "cub sort does not support sorting more than INT_MAX elements");
  // Map c10 scalar types onto the equivalent CUDA types CUB can sort.
  using key_t_ = typename detail::cuda_type<key_t>::type;

  auto allocator = c10::cuda::CUDACachingAllocator::get();
  // Owns the scratch key buffer (if any); released at function exit.
  c10::DataPtr keys_out_owner;

  if (keys_out == nullptr) {
    keys_out_owner = allocator->allocate(n * sizeof(key_t));
    keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
  }

  const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
      keys_in_, keys_out_, values_in, values_out, n,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
      keys_in_, keys_out_, values_in, values_out, n,
      begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
  }
}

// Explicit instantiation for one (key type, value size) combination.
#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \
  template void radix_sort_pairs_impl( \
    const key_t *keys_in, key_t *keys_out, \
    const OpaqueType<value_size> *values_in, \
    OpaqueType<value_size> *values_out, \
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);

AT_INSTANTIATE_SORT_PAIRS(int32_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 4)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)

// 8-byte values for every standard scalar key type (plus Bool and Half);
// the second argument is required by AT_FORALL_* but unused here.
#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \
  AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)

// BFloat16 Radix sort is supported from ROCm 4.5 onwards
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
#endif

} // namespace detail

template<typename key_t>
void radix_sort_keys(
const key_t *keys_in, key_t *keys_out,
int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"cub sort does not support sorting more than INT_MAX elements");
using key_t_ = typename detail::cuda_type<key_t>::type;

const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

if (descending) {
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
keys_in_, keys_out_, n,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
keys_in_, keys_out_, n,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
}

// Copies `input` to `output`, dropping consecutive duplicate elements
// (CUB DeviceSelect::Unique). The kept-element count is written to
// `*num_selected_out`. Runs on the current CUDA stream.
template<typename scalar_t>
void unique(const scalar_t *input, scalar_t *output, int64_t *num_selected_out, int64_t num_items) {
  // CUB's selection API takes a (signed) int item count.
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub unique does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
      input, output, num_selected_out, num_items,
      at::cuda::getCurrentCUDAStream());
}

// Run-length encodes `input` (CUB DeviceRunLengthEncode::Encode): unique
// runs go to `output`, each run's length to `counts_out`, and the number of
// runs to `*length_out`. Runs on the current CUDA stream.
template <typename scalar_t>
void run_length_encode(
    const scalar_t *input, scalar_t *output, int64_t *counts_out,
    int64_t *length_out, int64_t num_items) {
  // CUB's encode API takes a (signed) int item count.
  TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
    "cub run_length_encode does not support more than INT_MAX elements");
  CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
    input, output, counts_out, length_out, num_items,
    at::cuda::getCurrentCUDAStream());
}

// Explicitly instantiates radix_sort_keys/unique/run_length_encode for one
// scalar type; the ScalarType argument is required by AT_FORALL_* but unused.
// NOTE(review): macro name carries a typo ("INSTATIATE" for "INSTANTIATE"),
// and the last run_length_encode parameter is named `n` while the template
// declares `num_items` — cosmetic only, since parameter names in explicit
// instantiations do not affect the emitted symbols.
#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \
  template void radix_sort_keys( \
    const scalar_t *keys_in, scalar_t *keys_out, int64_t n, \
    bool descending, int64_t begin_bit, int64_t end_bit); \
  template void unique( \
    const scalar_t *input, scalar_t *output, \
    int64_t *num_selected_out, int64_t num_items); \
  template void run_length_encode( \
    const scalar_t *input, scalar_t *output, int64_t *counts_out, \
    int64_t *length_out, int64_t n);

// Instantiate for all standard scalar types plus Bool and Half.
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES)

namespace {
template <typename scalar_t>
Expand Down
Loading

0 comments on commit fb93c39

Please sign in to comment.