Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2E/N) (pytorch#3375)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#621

X-link: facebookresearch/FBGEMM#466


- Add `index_t` support to TBE training backward kernels

Differential Revision: D65933410
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 3, 2025
1 parent 617a434 commit 81ab9bb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t" }},
{%- endfor %}
Expand Down Expand Up @@ -105,7 +106,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
Expand Down Expand Up @@ -430,6 +431,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
Expand All @@ -446,6 +448,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ index_type }},
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_type_combo[ph_name].primitive_type }},
{%- endfor %}
Expand All @@ -470,7 +473,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
Expand Down Expand Up @@ -538,11 +541,13 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for ph_type_combo in args.placeholder_type_combos %}
{{ template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
Expand All @@ -552,6 +557,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t" }},
{%- endfor %}
Expand Down Expand Up @@ -77,7 +78,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> long_run_ids,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> num_long_run_ids,
Expand Down Expand Up @@ -1052,6 +1053,7 @@ Tensor {{ embedding_cuda_op }}(
<emb_t,
grad_t,
cache_t,
index_t,
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_name + "_ph_t" }},
{%- endfor %}
Expand Down Expand Up @@ -1101,7 +1103,7 @@ Tensor {{ embedding_cuda_op }}(
D,
{%- endif %}
MAKE_PTA_WITH_NAME(func_name3, hash_size_cumsum, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_run, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_run, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name3, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name3, long_run_ids, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name3, num_long_run_ids, int32_t, 1, 32),
Expand Down

0 comments on commit 81ab9bb

Please sign in to comment.