Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2D/N) (pytorch#3374)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#464


- Add `index_t` support to TBE training backward kernels

Differential Revision: D65930273
  • Loading branch information
q10 authored and facebook-github-bot committed Dec 14, 2024
1 parent a22ca25 commit 019dc22
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,14 @@ void split_embedding_backward_count_unique_indices_kernel

{% for vbe in [True, False] %}
{% set vdesc = "_vbe" if vbe else "" %}
template <typename grad_t>
template <typename grad_t, typename offset_t>
__global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
Expand Down Expand Up @@ -212,15 +212,16 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vdesc }}_kernel(
////////////////////////////////////////////////////////////////////////////////

{% for grad_type in ['at::Half', 'float', 'at::BFloat16'] %}
{% for offset_type in ['int32_t', 'int64_t'] %}
template __global__ __launch_bounds__(kMaxThreads)
void grad_mean{{ vdesc }}_kernel
<{{ grad_type }}> (
<{{ grad_type }}, {{ offset_type }}> (
pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output_mean,
const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits>
grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<{{ offset_type }}, 1, at::RestrictPtrTraits> offsets,
{% if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> row_grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
Expand All @@ -230,6 +231,7 @@ void grad_mean{{ vdesc }}_kernel
FixedDivisor fd_B
{% endif %}
);
{% endfor %} // for offset_type in ['int32_t', 'int64_t']
{% endfor %} // for grad_type in ['at::Half', 'float']
{% endfor %} // for vbe in [True, False]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,13 @@ split_embedding_backward_codegen_find_long_segments(
const bool use_deterministic_algorithms);


template <typename grad_t>
template <typename grad_t, typename offset_t>
__global__ __launch_bounds__(kMaxThreads) void
grad_mean{{ vdesc }}_kernel(
pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output_mean,
const pta::PackedTensorAccessor64<grad_t, 2, at::RestrictPtrTraits> grad_output,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
const pta::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits> offsets,
{%- if vbe %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> b_t_map,
Expand Down Expand Up @@ -860,7 +860,7 @@ Tensor {{ embedding_cuda_op }}(
MAKE_PTA_WITH_NAME(func_name1, grad_output_mean, grad_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name1, grad_output_reshaped, grad_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
{%- if vbe %}
MAKE_PTA_WITH_NAME(func_name1, vbe_row_output_offsets, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name1, vbe_b_t_map, int32_t, 1, 32),
Expand Down

0 comments on commit 019dc22

Please sign in to comment.