Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2G/N) (pytorch#3377)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#622


X-link: facebookresearch/FBGEMM#468

- Add `index_t` support to TBE training backward kernels

Reviewed By: basilwong

Differential Revision: D65960050
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 9, 2025
1 parent 3c792d6 commit bac1b7d
Showing 1 changed file with 21 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread
>
__global__ __launch_bounds__(kForwardMaxThreads) void
Expand All @@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
{%- if not dense %}
const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }},
{%- endif %}
Expand Down Expand Up @@ -113,17 +114,17 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
fd_B.DivMod(b_t, &t, &b);
{%- endif %}

int64_t weights_offset = weights_offsets[t];
int32_t D_start = D_offsets[t];
int32_t D_end = D_offsets[t + 1];
int32_t D = D_end - D_start;
int64_t indices_start = offsets[b_t];
int64_t indices_end = offsets[b_t + 1];
int32_t L = indices_end - indices_start;
const auto weights_offset = weights_offsets[t];
const auto D_start = D_offsets[t];
const auto D_end = D_offsets[t + 1];
const auto D = D_end - D_start;
const auto indices_start = offsets[b_t];
const auto indices_end = offsets[b_t + 1];
const auto L = indices_end - indices_start;
if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) {
// If the table does not require gradient computation, we set the gradient to zero.
for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
for (auto l_start = 0; l_start < L; l_start += kWarpSize) {
auto l = l_start + threadIdx.x;
if (l < L) {
grad_indice_weights[indices_start + l] = 0.0;
}
Expand Down Expand Up @@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void

for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
int32_t l = l_start + threadIdx.x;
int64_t idx = l < L ? indices[indices_start + l] : 0;
index_t idx = l < L ? indices[indices_start + l] : 0;
{%- if not dense %}
const auto {{ locs_or_addrs_idx }} =
(placement == PlacementType::MANAGED_CACHING && l < L)
? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
{%- endif %}
for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
int64_t idx_j = shfl_sync(idx, j);
auto idx_j = shfl_sync(idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
{%- endif %}
Expand Down Expand Up @@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
const uint32_t info_B_mask = info_B_mask_int64;
{%- endif %}

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] {
DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
aligned_grad_output.scalar_type(),
Expand All @@ -362,7 +364,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
{%- else %}
dev_weights.scalar_type(),
{%- endif %}
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel",
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2",
[&] {
{%- if vbe %}
const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1});
Expand All @@ -379,13 +381,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
mdesc, vdesc, vbdesc)
%}
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"{{ kernel_name }}";
const auto func_name = "{{ kernel_name }}";
#endif
{{ kernel_name }}<
emb_t,
grad_t,
cache_t,
index_t,
kFixedMaxVecsPerThread><<<
div_round_up(total_B, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
Expand All @@ -400,8 +402,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
{%- endif %}
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
{%- if not dense %}
MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
{%- endif %}
Expand All @@ -421,6 +423,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
});
{%- endfor %} {# /* for use_vec_blocking */ #}
});
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
return grad_indice_weights;
Expand Down

0 comments on commit bac1b7d

Please sign in to comment.