From af8d89d4bedfd2159d91dbdd558c9c73b7010193 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 14 Nov 2024 13:02:11 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2C/N) (#3372) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/465 - Add `index_t` support to TBE training backward kernels Differential Revision: D65925354 --- .../backward/embedding_backward_split_template.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fdd9c0f798..7c4d85fc33 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -742,6 +742,7 @@ Tensor {{ embedding_cuda_op }}( else { {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }}); size_t temp_storage_bytes = 0; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] { AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, @@ -753,9 +754,11 @@ Tensor {{ embedding_cuda_op }}( 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream())); + auto temp_storage = at::empty( {static_cast<int64_t>(temp_storage_bytes)}, indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, @@ -767,6 +770,7 @@ Tensor {{ embedding_cuda_op }}( 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream())); + }); } } @@ -775,6 +779,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -800,9 +805,11 @@ Tensor {{ embedding_cuda_op }}( 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream())); + auto temp_storage = at::empty( {static_cast<int64_t>(temp_storage_bytes)}, 
indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, @@ -1181,6 +1188,7 @@ Tensor {{ embedding_cuda_op }}( }); // DISPATCH_OPTIMAL_KERNEL }); // DISPATCH_EMB_GRAD_CACHE_TYPES + }); // AT_DISPATCH_INDEX_TYPES {%- if dense %} return grad_dev_weights;