From 00bb1db32e26d1b35770a6b66bb814ef7e4c6d30 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 2 Jan 2025 12:36:47 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2G/N) (#3377) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/468 - Add `index_t` support to TBE training backward kernels Differential Revision: D65960050 --- ..._backward_split_indice_weights_template.cu | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 8e1db36757..5c4e783d32 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -64,6 +64,7 @@ template < typename emb_t, typename grad_t, typename cache_t, + typename index_t, int32_t kFixedMaxVecsPerThread > __global__ __launch_bounds__(kForwardMaxThreads) void @@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, const pta::PackedTensorAccessor32 D_offsets, - const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] - const pta::PackedTensorAccessor32 offsets, // [B x T + 1] + const pta::PackedTensorAccessor32 indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L] + const pta::PackedTensorAccessor32 offsets, // [B x T + 1] {%- if not dense %} const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }}, {%- endif %} @@ -113,17 +114,17 @@ __global__ __launch_bounds__(kForwardMaxThreads) void fd_B.DivMod(b_t, &t, &b); {%- endif %} - int64_t weights_offset = weights_offsets[t]; - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - int64_t indices_start = offsets[b_t]; - int64_t indices_end = offsets[b_t + 1]; - int32_t L = indices_end - indices_start; + const auto weights_offset = weights_offsets[t]; + const auto D_start = D_offsets[t]; + const auto D_end = D_offsets[t + 1]; + const auto D = D_end - D_start; + const auto indices_start = offsets[b_t]; + const auto indices_end = offsets[b_t + 1]; + const auto L = indices_end - indices_start; if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) { // If the table does not require gradient computation, we set the gradient to zero. - for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { - int32_t l = l_start + threadIdx.x; + for (auto l_start = 0; l_start < L; l_start += kWarpSize) { + auto l = l_start + threadIdx.x; if (l < L) { grad_indice_weights[indices_start + l] = 0.0; } @@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) { int32_t l = l_start + threadIdx.x; - int64_t idx = l < L ? indices[indices_start + l] : 0; + index_t idx = l < L ? indices[indices_start + l] : 0; {%- if not dense %} const auto {{ locs_or_addrs_idx }} = (placement == PlacementType::MANAGED_CACHING && l < L) ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0; {%- endif %} for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { - int64_t idx_j = shfl_sync(idx, j); + auto idx_j = shfl_sync(idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); {%- endif %} @@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( const uint32_t info_B_mask = info_B_mask_int64; {%- endif %} + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] { DISPATCH_EMB_GRAD_CACHE_TYPES( dev_weights.scalar_type(), aligned_grad_output.scalar_type(), @@ -362,7 +364,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( {%- else %} dev_weights.scalar_type(), {%- endif %} - "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel", + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2", [&] { {%- if vbe %} const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1}); @@ -379,13 +381,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( mdesc, vdesc, vbdesc) %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = - "{{ kernel_name }}"; + const auto func_name = "{{ kernel_name }}"; #endif {{ kernel_name }}< emb_t, grad_t, cache_t, + index_t, kFixedMaxVecsPerThread><<< div_round_up(total_B, kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -400,8 +402,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( {%- endif %} MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), {%- if not dense %} MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32), {%- endif %} @@ -421,6 +423,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( }); {%- endfor %} {# /* for use_vec_blocking */ #} }); + }); C10_CUDA_KERNEL_LAUNCH_CHECK(); return grad_indice_weights;