# [build] Split `.cu` to improve compile times (pytorch#81193)

The goal is to speed up CUDA builds. I was looking at build times and found that we have large CUDA compilation units that take forever to compile and make parallelism less effective. This PR splits them up into separate `.cu` files so we can parallelize compilation better. We've done this sort of thing in the past with some success.

With a cold build, timing before: 5m42.019s; timing after: 4m30.275s. That's a speedup of 18.1% for me. Behaviorally this should be a no-op; I'm just moving code around. There is still more we can do here, but I did most of the ones that are copypasta. The full list of remaining chonky compilation units is [here](https://gist.github.com/suo/0dc217733f40f59898a8cc4f60529d60).

## Details

Here's a screenshot from a ninja trace of a build with the following command:

```
MAX_JOBS=64 CCACHE_DISABLE=1 TORCH_CUDA_ARCH_LIST=Ampere BUILD_CAFFE2_OPS=0 USE_FBGEMM=0 USE_DISTRIBUTED=0 USE_MKLDNN=0 BUILD_TEST=0 USE_GOLD_LINKER=1 USE_OPENMP=1 USE_NCCL=0 DEBUG=0 python setup.py develop
```

![image](https://user-images.githubusercontent.com/1617424/178170276-ee0e5eb0-4c16-4b86-b4af-2a9e615b7f5f.png)

([source trace](https://gist.github.com/suo/5f5458f2630f9ab6dcbea6989e892195), which you can visualize in [perfetto](https://ui.perfetto.dev/))

After this PR, we get somewhat better utilization (although there is plenty still left to do):

![image](https://user-images.githubusercontent.com/1617424/178178944-63ca9ff0-9cd3-4008-9a6d-d8623b5148c5.png)

([source trace](https://gist.github.com/suo/5607335bcd4bd412d42b0c9334259184))

Pull Request resolved: pytorch#81193
Approved by: https://github.com/cpuhrsch, https://github.com/malfet
Commit `fb93c39` · 1 parent: `282de55`

Showing 48 changed files with 2,952 additions and 1,914 deletions.
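The mechanism behind the split is ordinary explicit template instantiation: the template definition (and any kernel it launches) stays in a shared header, and each new `.cu` file instantiates one slice of the supported types, giving ninja independent compile jobs. Here is a minimal sketch of that pattern; the names (`scale_kernel`, `launch_scale`, the `ops_*.cu` files) are hypothetical, not from this commit:

```cpp
// ops.cuh -- shared definition (hypothetical example, not from the commit)
#pragma once
#include <cstdint>

template <typename T>
__global__ void scale_kernel(const T* in, T* out, int64_t n, T factor) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * factor;
  }
}

template <typename T>
void launch_scale(const T* in, T* out, int64_t n, T factor) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  scale_kernel<T><<<blocks, threads>>>(in, out, n, factor);
}

// ops_float.cu -- one compilation unit owns the float instantiation
// #include "ops.cuh"
template void launch_scale<float>(const float*, float*, int64_t, float);

// ops_int64.cu -- a sibling unit owns the int64_t instantiation, so the
// two files compile in parallel instead of as one large unit
// #include "ops.cuh"
template void launch_scale<int64_t>(
    const int64_t*, int64_t*, int64_t, int64_t);
```

Each explicit instantiation becomes an out-of-line symbol in its own object file, which is exactly how the two new `cub` translation units below divide up the scalar types.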
New file (`@@ -0,0 +1,108 @@`):

```cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/cub.cuh>

namespace at {
namespace cuda {
namespace cub {

template <typename key_t>
void radix_sort_keys(
    const key_t* keys_in,
    key_t* keys_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
  using key_t_ = typename detail::cuda_type<key_t>::type;

  const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
        keys_in_,
        keys_out_,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
        keys_in_,
        keys_out_,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  }
}

template <typename scalar_t>
void unique(
    const scalar_t* input,
    scalar_t* output,
    int64_t* num_selected_out,
    int64_t num_items) {
  TORCH_CHECK(
      num_items <= std::numeric_limits<int>::max(),
      "cub unique does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
      input,
      output,
      num_selected_out,
      num_items,
      at::cuda::getCurrentCUDAStream());
}

template <typename scalar_t>
void run_length_encode(
    const scalar_t* input,
    scalar_t* output,
    int64_t* counts_out,
    int64_t* length_out,
    int64_t num_items) {
  TORCH_CHECK(
      num_items <= std::numeric_limits<int>::max(),
      "cub run_length_encode does not support more than INT_MAX elements");
  CUB_WRAPPER(
      NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
      input,
      output,
      counts_out,
      length_out,
      num_items,
      at::cuda::getCurrentCUDAStream());
}

#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \
  template void radix_sort_keys(                          \
      const scalar_t* keys_in,                            \
      scalar_t* keys_out,                                 \
      int64_t n,                                          \
      bool descending,                                    \
      int64_t begin_bit,                                  \
      int64_t end_bit);                                   \
  template void unique(                                   \
      const scalar_t* input,                              \
      scalar_t* output,                                   \
      int64_t* num_selected_out,                          \
      int64_t num_items);                                 \
  template void run_length_encode(                        \
      const scalar_t* input,                              \
      scalar_t* output,                                   \
      int64_t* counts_out,                                \
      int64_t* length_out,                                \
      int64_t n);

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES)

} // namespace cub
} // namespace cuda
} // namespace at
```
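For orientation, here is what a call site for one of these instantiations could look like. This is a hypothetical usage sketch, assuming the matching declarations are exposed through `ATen/cuda/cub.h`; the function `sort_int64_keys_ascending` and its buffer arguments are made up for illustration.

```cpp
// Hypothetical call site (illustrative only). Assumes d_keys_in/d_keys_out
// are device pointers on the current CUDA device and that radix_sort_keys
// is declared in ATen/cuda/cub.h.
#include <cstdint>
#include <ATen/cuda/cub.h>

void sort_int64_keys_ascending(
    const int64_t* d_keys_in, int64_t* d_keys_out, int64_t n) {
  // Sort over the full bit range of the key type; the explicit
  // instantiations above are what make this link without including cub.cuh.
  at::cuda::cub::radix_sort_keys(
      d_keys_in,
      d_keys_out,
      n,
      /*descending=*/false,
      /*begin_bit=*/0,
      /*end_bit=*/static_cast<int64_t>(8 * sizeof(int64_t)));
}
```

Because callers only need the declaration, the heavy cub headers are compiled once here rather than in every file that sorts.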
New file (`@@ -0,0 +1,93 @@`):

```cpp
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/cub.cuh>

namespace at {
namespace cuda {
namespace cub {
namespace detail {

template <typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t* keys_in,
    key_t* keys_out,
    const OpaqueType<value_size>* values_in,
    OpaqueType<value_size>* values_out,
    int64_t n,
    bool descending,
    int64_t begin_bit,
    int64_t end_bit) {
  TORCH_CHECK(
      n <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
  using key_t_ = typename detail::cuda_type<key_t>::type;

  auto allocator = c10::cuda::CUDACachingAllocator::get();
  c10::DataPtr keys_out_owner;

  if (keys_out == nullptr) {
    keys_out_owner = allocator->allocate(n * sizeof(key_t));
    keys_out = reinterpret_cast<key_t*>(keys_out_owner.get());
  }

  const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
  key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);

  if (descending) {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
        keys_in_,
        keys_out_,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  } else {
    CUB_WRAPPER(
        NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
        keys_in_,
        keys_out_,
        values_in,
        values_out,
        n,
        begin_bit,
        end_bit,
        c10::cuda::getCurrentCUDAStream());
  }
}

#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \
  template void radix_sort_pairs_impl(               \
      const key_t* keys_in,                          \
      key_t* keys_out,                               \
      const OpaqueType<value_size>* values_in,       \
      OpaqueType<value_size>* values_out,            \
      int64_t n,                                     \
      bool descending,                               \
      int64_t begin_bit,                             \
      int64_t end_bit);

AT_INSTANTIATE_SORT_PAIRS(int32_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int32_t, 4)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 1)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 2)
AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)

#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \
  AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)

AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)

// BFloat16 Radix sort is supported from ROCm 4.5 onwards
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
#endif

} // namespace detail

} // namespace cub
} // namespace cuda
} // namespace at
```
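The `OpaqueType<value_size>` parameter is what keeps the instantiation count bounded: radix sort only moves values as raw bytes, so every value type of a given size can share one instantiation, and a typed front-end funnels arbitrary pair types into these size-keyed symbols. The following is a minimal sketch of that funneling pattern, with stand-in declarations; it is illustrative, not the commit's exact header code.

```cpp
#include <cstdint>
#include <type_traits>

// Stand-ins for the declarations in the cub code above (sketch only).
template <int size>
struct OpaqueType {
  char data[size];
};

template <typename key_t, int value_size>
void radix_sort_pairs_impl(
    const key_t*, key_t*,
    const OpaqueType<value_size>*, OpaqueType<value_size>*,
    int64_t, bool, int64_t, int64_t);

// Typed front-end: any trivially-copyable value type of a given size reuses
// the single OpaqueType<size> instantiation, which is why the file above
// only needs instantiations for value sizes 1, 2, 4, and 8.
template <typename key_t, typename value_t>
void radix_sort_pairs(
    const key_t* keys_in, key_t* keys_out,
    const value_t* values_in, value_t* values_out,
    int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) {
  static_assert(std::is_trivially_copyable<value_t>::value,
                "radix sort moves values as raw bytes");
  using opaque_t = OpaqueType<sizeof(value_t)>;
  radix_sort_pairs_impl(
      keys_in, keys_out,
      reinterpret_cast<const opaque_t*>(values_in),
      reinterpret_cast<opaque_t*>(values_out),
      n, descending, begin_bit, end_bit);
}
```

With this shape, sorting pairs with a new value type costs no additional CUDA compilation unless the value size itself is new.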