Add support for int32_t indices in TBE training (2F/N) (pytorch#3376)
Summary:
X-link: facebookresearch/FBGEMM#623


X-link: facebookresearch/FBGEMM#467

- Add `index_t` support to the TBE training backward kernels, templating them over both `int32_t` and `int64_t` indices

Differential Revision: D65938455
q10 authored and facebook-github-bot committed Jan 7, 2025
1 parent 1573d07 commit 8bf5bbf
Showing 3 changed files with 36 additions and 18 deletions.
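
The diff threads a new `index_t` template parameter through the backward kernels and pre-instantiates it for both `int32_t` and `int64_t` via the Jinja loops below. As a minimal host-side sketch, assuming callers follow PyTorch's usual index-dispatch pattern, `AT_DISPATCH_INDEX_TYPES` would select the matching instantiation at runtime; the function and tensor names here are illustrative, not part of this commit:

#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

// Sketch only: dispatch on the dtype of a (hypothetical) index tensor so the
// kernel instantiation compiled for that index_t is the one launched.
void launch_tbe_backward_sketch(const at::Tensor& sorted_linear_indices_run) {
  AT_DISPATCH_INDEX_TYPES(
      sorted_linear_indices_run.scalar_type(), "tbe_backward_sketch", [&] {
        // Inside the lambda, `index_t` is int32_t for at::kInt inputs and
        // int64_t for at::kLong inputs.
        const index_t* run_values =
            sorted_linear_indices_run.data_ptr<index_t>();
        // ... launch the kernel templated on this index_t ...
        (void)run_values;
      });
}
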
@@ -62,6 +62,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t"}},
{%- endfor %}
@@ -90,7 +91,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ index_type }},
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_type_combo[ph_name].primitive_type }},
{%- endfor %}
@@ -381,7 +384,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for ph_type_combo in args.placeholder_type_combos %}
{{ template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
ph_type_combo,
kFixedMaxVecsPerThread,
kThreadGroupSize,
@@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}


@@ -533,6 +539,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread,
int32_t kThreadGroupSize,
bool kUseVecBlocking,
@@ -556,7 +563,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -652,6 +659,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
emb_t,
cache_t,
grad_t,
index_t,
BLOCK_SIZE,
embedding_dim,
segment_prefetch,
@@ -684,6 +692,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
emb_type,
grad_type,
cache_type,
index_type,
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
@@ -696,6 +705,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
< {{ emb_type }},
{{ grad_type }},
{{ cache_type }},
{{ index_type }},
{{ kFixedMaxVecsPerThread }},
{{ kThreadGroupSize }},
{{ kUseVecBlocking }},
@@ -718,7 +728,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -764,12 +774,14 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
{%- for emb_type in ['float', 'at::Half'] %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
{%- for kWeighDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,
grad_type,
cache_type,
index_type,
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
@@ -782,6 +794,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endfor %}
{%- endmacro %}

{%- macro hip_instantiate_templates(use_subwarp_shuffle) %}
@@ -139,6 +139,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
{%- for ph_name in args.placeholder_tensor_names %}
typename {{ ph_name + "_ph_t" }},
{%- endfor %}
@@ -167,7 +168,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -224,6 +225,7 @@ template <
typename emb_t,
typename grad_t,
typename cache_t,
typename index_t,
int32_t kFixedMaxVecsPerThread,
int32_t kThreadGroupSize,
bool kUseVecBlocking,
@@ -247,7 +249,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
int64_t D,
{%- endif %}
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
{%- if not nobag %}
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -826,8 +828,8 @@ Tensor {{ embedding_cuda_op }}(
AT_CUDA_CHECK(radix_sort_pairs(
nullptr,
temp_storage_bytes,
linear_indices.data_ptr<int64_t>(),
linear_indices_sorted.data_ptr<int64_t>(),
linear_indices.data_ptr<index_t>(),
linear_indices_sorted.data_ptr<index_t>(),
{{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(),
{{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(),
linear_indices.numel(),
@@ -842,8 +844,8 @@ Tensor {{ embedding_cuda_op }}(
AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
linear_indices.data_ptr<int64_t>(),
linear_indices_sorted.data_ptr<int64_t>(),
linear_indices.data_ptr<index_t>(),
linear_indices_sorted.data_ptr<index_t>(),
{{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(),
{{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(),
linear_indices.numel(),
@@ -888,8 +890,8 @@ Tensor {{ embedding_cuda_op }}(
AT_CUDA_CHECK(radix_sort_pairs(
nullptr,
temp_storage_bytes,
linear_indices.data_ptr<int64_t>(),
linear_indices_sorted.data_ptr<int64_t>(),
linear_indices.data_ptr<index_t>(),
linear_indices_sorted.data_ptr<index_t>(),
indice_weights.data_ptr<at::acc_type<cache_t, true>>(),
indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(),
linear_indices.numel(),
@@ -904,8 +906,8 @@ Tensor {{ embedding_cuda_op }}(
AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
linear_indices.data_ptr<int64_t>(),
linear_indices_sorted.data_ptr<int64_t>(),
linear_indices.data_ptr<index_t>(),
linear_indices_sorted.data_ptr<index_t>(),
indice_weights.data_ptr<at::acc_type<cache_t, true>>(),
indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(),
linear_indices.numel(),
@@ -1174,6 +1176,7 @@ Tensor {{ embedding_cuda_op }}(
<emb_t,
grad_t,
cache_t,
index_t,
{%- for ph_name in args.placeholder_tensor_names %}
{{ ph_name + "_ph_t" }},
{%- endfor %}
Expand Down Expand Up @@ -1225,6 +1228,7 @@ Tensor {{ embedding_cuda_op }}(
<emb_t,
grad_t,
cache_t,
index_t,
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
@@ -1264,7 +1268,7 @@ Tensor {{ embedding_cuda_op }}(
D,
{%- endif %}
MAKE_PTA_WITH_NAME(func_name4, hash_size_cumsum, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_run, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_run, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
{%- if not nobag %}
MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int32_t, 1, 32),
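
The `radix_sort_pairs` hunks above switch the sort keys from hard-coded `int64_t` to `index_t`. Each sort is issued twice, which is CUB's standard two-phase pattern: a first call with a null temp-storage pointer that only computes `temp_storage_bytes`, then the real sort. Below is a standalone sketch of that pattern, under the assumption that FBGEMM's `radix_sort_pairs` wraps `cub::DeviceRadixSort::SortPairs`; the function name is illustrative:

#include <cub/device/device_radix_sort.cuh>

// Sketch only: the size-query-then-sort idiom, templated on the key type the
// way the generated code now is. With int32_t keys, end_bit is half that of
// int64_t keys, so the radix sort needs fewer passes over the data.
template <typename index_t, typename value_t>
cudaError_t sort_pairs_sketch(
    const index_t* keys_in, index_t* keys_out,
    const value_t* values_in, value_t* values_out,
    int num_items, cudaStream_t stream) {
  size_t temp_storage_bytes = 0;
  // Phase 1: null temp storage -> CUB only reports the bytes it needs.
  cub::DeviceRadixSort::SortPairs(
      nullptr, temp_storage_bytes, keys_in, keys_out, values_in, values_out,
      num_items, 0, static_cast<int>(sizeof(index_t) * 8), stream);
  void* temp_storage = nullptr;
  cudaMallocAsync(&temp_storage, temp_storage_bytes, stream);
  // Phase 2: the actual key/value sort.
  const cudaError_t err = cub::DeviceRadixSort::SortPairs(
      temp_storage, temp_storage_bytes, keys_in, keys_out, values_in,
      values_out, num_items, 0, static_cast<int>(sizeof(index_t) * 8), stream);
  cudaFreeAsync(temp_storage, stream);
  return err;
}
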
@@ -108,6 +108,7 @@ template <typename optimizer_t,
typename emb_t,
typename cache_t,
typename grad_t,
typename index_t,
int32_t block_size,
int32_t embedding_dim,
int32_t segment_prefetch, // 2
@@ -118,7 +119,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
const grad_t* p_output_grad,
emb_t* p_emb_table,
const int64_t* p_hash_size_cumsum,
const int64_t* p_sorted_linear_indices_run,
const index_t* p_sorted_linear_indices_run,
const int32_t* p_sorted_linear_indices_cumulative_run_lengths,
const int32_t* p_sorted_linear_indices_num_runs,
{%- if not nobag %}
@@ -151,7 +152,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
return;
}

const int64_t linear_index = p_sorted_linear_indices_run[run_id];
const auto linear_index = p_sorted_linear_indices_run[run_id];

const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id];
const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1];
@@ -458,4 +459,4 @@ L_tail_grad_acc:

store_row_per_warp<emb_t, embedding_dim, emb_t>::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id);
}
} // namespace fbgemm_gpu::rocm
} // namespace fbgemm_gpu::rocm
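
On the device side, the only change the kernel body needs is loading the run value through `auto` instead of a hard-coded `int64_t`, as the hunk above shows: mixed arithmetic with the `int64_t` hash-size cumsum still promotes to 64 bits, so the narrower type appears only where indices are stored, not where offsets are computed. A minimal illustration of that promotion (not FBGEMM's actual kernel):

// Sketch only: an index_t-typed load whose subsequent arithmetic against an
// int64_t cumulative sum is promoted to int64_t by the usual conversions.
template <typename index_t>
__device__ int64_t row_within_table_sketch(
    const index_t* sorted_linear_indices_run,
    const int64_t* hash_size_cumsum,
    int32_t run_id,
    int32_t table_idx) {
  const auto linear_index = sorted_linear_indices_run[run_id];  // index_t
  return linear_index - hash_size_cumsum[table_idx];            // int64_t
}
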
