From bb0e66474992980fbf8b54013650a8f5cb57da39 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 2 Jan 2025 21:47:15 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2C/N) (#3372) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/619 X-link: https://github.com/facebookresearch/FBGEMM/pull/465 - Add `index_t` support to TBE training backward kernels Differential Revision: D65925354 --- .../backward/embedding_backward_split_template.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 5029a382a..3199a1b00 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -821,6 +821,7 @@ Tensor {{ embedding_cuda_op }}( else { {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }}); size_t temp_storage_bytes = 0; + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] { AT_CUDA_CHECK(radix_sort_pairs( nullptr, temp_storage_bytes, @@ -832,9 +833,11 @@ Tensor {{ embedding_cuda_op }}( 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream())); + auto temp_storage = at::empty( {static_cast(temp_storage_bytes)}, indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, @@ -846,6 +849,7 @@ Tensor {{ embedding_cuda_op }}( 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream())); + }); } } @@ -865,6 +869,7 @@ Tensor {{ embedding_cuda_op }}( %} {%- endif %} + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] { DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -890,9 +895,11 @@ Tensor {{ embedding_cuda_op }}( 0, total_hash_size_bits, at::cuda::getCurrentCUDAStream())); + auto temp_storage = at::empty( {static_cast(temp_storage_bytes)}, indices.options().dtype(at::kByte)); + AT_CUDA_CHECK(radix_sort_pairs( temp_storage.data_ptr(), temp_storage_bytes, @@ -1308,6 +1315,7 @@ Tensor {{ embedding_cuda_op }}( }); // DISPATCH_OPTIMAL_KERNEL }); // DISPATCH_EMB_GRAD_CACHE_TYPES + }); // AT_DISPATCH_INDEX_TYPES {%- if dense %} return grad_dev_weights;