From 81ab9bbca82ef769dabd7c6f06ffe0c152762ffc Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 2 Jan 2025 21:47:15 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2E/N) (#3375) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/621 X-link: https://github.com/facebookresearch/FBGEMM/pull/466 - Add `index_t` support to TBE training backward kernels Differential Revision: D65933410 --- .../embedding_backward_split_kernel_cta_template.cu | 10 ++++++++-- .../backward/embedding_backward_split_template.cu | 6 ++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu index 3fb49ed5e7..1cfeb66c94 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu @@ -77,6 +77,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, {%- for ph_name in args.placeholder_tensor_names %} typename {{ ph_name + "_ph_t" }}, {%- endfor %} @@ -105,7 +106,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, const pta::PackedTensorAccessor32 long_run_ids, const pta::PackedTensorAccessor32 num_long_run_ids, @@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( emb_type, grad_type, cache_type, + index_type, ph_type_combo, kFixedMaxVecsPerThread, kThreadGroupSize, @@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row < {{ emb_type }}, {{ grad_type }}, {{ cache_type }}, + {{ 
index_type }}, {%- for ph_name in args.placeholder_tensor_names %} {{ ph_type_combo[ph_name].primitive_type }}, {%- endfor %} @@ -470,7 +473,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, const pta::PackedTensorAccessor32 long_run_ids, const pta::PackedTensorAccessor32 num_long_run_ids, @@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %} {%- for emb_type in ['float', 'at::Half'] %} {%- for cache_type in ['float', 'at::Half'] %} + {%- for index_type in ['int32_t', 'int64_t'] %} {%- for ph_type_combo in args.placeholder_type_combos %} {{ template_instantiation( emb_type, grad_type, cache_type, + index_type, ph_type_combo, kFixedMaxVecsPerThread, kThreadGroupSize, @@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row {%- endfor %} {%- endfor %} {%- endfor %} + {%- endfor %} {%- endmacro %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 2f661a633c..03649af68f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -49,6 +49,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, {%- for ph_name in args.placeholder_tensor_names %} typename {{ ph_name + "_ph_t" }}, {%- endfor %} @@ -77,7 +78,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row( int64_t D, {%- endif %} const pta::PackedTensorAccessor32 hash_size_cumsum, - const 
pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run, + const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run, const pta::PackedTensorAccessor32 sorted_linear_indices_cumulative_run_lengths, const pta::PackedTensorAccessor32 long_run_ids, const pta::PackedTensorAccessor32 num_long_run_ids, @@ -1052,6 +1053,7 @@ Tensor {{ embedding_cuda_op }}(