diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index fdd9c0f798..7c4d85fc33 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -742,6 +742,7 @@ Tensor {{ embedding_cuda_op }}(
     else {
       {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }});
       size_t temp_storage_bytes = 0;
+      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] {
       AT_CUDA_CHECK(radix_sort_pairs(
           nullptr,
           temp_storage_bytes,
@@ -753,9 +754,11 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           indices.options().dtype(at::kByte));
+
       AT_CUDA_CHECK(radix_sort_pairs(
           temp_storage.data_ptr(),
           temp_storage_bytes,
@@ -767,6 +770,7 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+      });
     }
   }
 
@@ -775,6 +779,7 @@ Tensor {{ embedding_cuda_op }}(
   }
   {%- endif %}
 
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] {
     DISPATCH_EMB_GRAD_CACHE_TYPES(
         dev_weights.scalar_type(),
         aligned_grad_output.scalar_type(),
@@ -800,9 +805,11 @@ Tensor {{ embedding_cuda_op }}(
             0,
             total_hash_size_bits,
             at::cuda::getCurrentCUDAStream()));
+
         auto temp_storage = at::empty(
             {static_cast<int64_t>(temp_storage_bytes)},
             indices.options().dtype(at::kByte));
+
         AT_CUDA_CHECK(radix_sort_pairs(
             temp_storage.data_ptr(),
             temp_storage_bytes,
@@ -1181,6 +1188,7 @@ Tensor {{ embedding_cuda_op }}(
         }); // DISPATCH_OPTIMAL_KERNEL
     }); // DISPATCH_EMB_GRAD_CACHE_TYPES
+    }); // AT_DISPATCH_INDEX_TYPES
 
 {%- if dense %}
     return grad_dev_weights;
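
Note on the pattern this diff introduces: AT_DISPATCH_INDEX_TYPES is the ATen macro that switches on an index tensor's dtype (int32 or int64) and exposes the selected type as index_t inside the lambda, so typed accesses such as data_ptr<index_t>() compile for either index width. Below is a minimal, self-contained sketch of that pattern outside FBGEMM, assuming a libtorch/ATen build; the helper name sum_indices is hypothetical and not part of this PR.

// Minimal sketch (not FBGEMM code): AT_DISPATCH_INDEX_TYPES selects the lambda
// instantiation matching the tensor's index dtype and defines index_t inside it.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <cstdint>
#include <iostream>

// Hypothetical helper, for illustration only: sums an index tensor through a
// statically typed pointer, the same shape of code the radix_sort_pairs calls use.
int64_t sum_indices(const at::Tensor& indices) {
  int64_t total = 0;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_indices", [&] {
    // Here index_t is int32_t or int64_t, depending on indices.scalar_type().
    const auto* ptr = indices.data_ptr<index_t>();
    for (int64_t i = 0; i < indices.numel(); ++i) {
      total += static_cast<int64_t>(ptr[i]);
    }
  });
  return total;
}

int main() {
  auto idx = at::arange(10, at::dtype(at::kInt));  // int32 indices
  std::cout << sum_indices(idx) << std::endl;      // prints 45
  return 0;
}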