diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu
index 0e4f552ebc..1732239db2 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu
@@ -62,6 +62,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t"}},
     {%- endfor %}
@@ -90,7 +91,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     ph_type_combo,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
@@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
 < {{ emb_type }},
   {{ grad_type }},
   {{ cache_type }},
+  {{ index_type }},
   {%- for ph_name in args.placeholder_tensor_names %}
   {{ ph_type_combo[ph_name].primitive_type }},
   {%- endfor %}
@@ -381,7 +384,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
     {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
     {%- for emb_type in ['float', 'at::Half'] %}
     {%- for cache_type in ['float', 'at::Half'] %}
+    {%- for index_type in ['int32_t', 'int64_t'] %}
     {%- for ph_type_combo in args.placeholder_type_combos %}
     {{ template_instantiation(
         emb_type,
         grad_type,
         cache_type,
+        index_type,
         ph_type_combo,
         kFixedMaxVecsPerThread,
         kThreadGroupSize,
@@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
     {%- endfor %}
     {%- endfor %}
     {%- endfor %}
+    {%- endfor %}
 {%- endmacro %}
 
 
@@ -533,6 +539,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     int32_t kFixedMaxVecsPerThread,
     int32_t kThreadGroupSize,
     bool kUseVecBlocking,
@@ -556,7 +563,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -652,6 +659,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
         emb_t,
         cache_t,
         grad_t,
+        index_t,
         BLOCK_SIZE,
         embedding_dim,
         segment_prefetch,
@@ -684,6 +692,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
     kUseVecBlocking,
@@ -696,6 +705,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
 < {{ emb_type }},
   {{ grad_type }},
   {{ cache_type }},
+  {{ index_type }},
   {{ kFixedMaxVecsPerThread }},
   {{ kThreadGroupSize }},
   {{ kUseVecBlocking }},
@@ -718,7 +728,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -764,12 +774,14 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
     {%- for emb_type in ['float', 'at::Half'] %}
     {%- for cache_type in ['float', 'at::Half'] %}
+    {%- for index_type in ['int32_t', 'int64_t'] %}
     {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
    {%- for kWeighDecayMode in [0, 1, 2] %}
     {{ hip_template_instantiation(
         emb_type,
         grad_type,
         cache_type,
+        index_type,
         kFixedMaxVecsPerThread,
         kThreadGroupSize,
         kUseVecBlocking,
@@ -782,6 +794,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     {%- endfor %}
     {%- endfor %}
     {%- endfor %}
+    {%- endfor %}
 {%- endmacro %}
 
 {%- macro hip_instantiate_templates(use_subwarp_shuffle) %}
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index 03649af68f..46593e2ce7 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -139,6 +139,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -167,7 +168,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -224,6 +225,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     int32_t kFixedMaxVecsPerThread,
     int32_t kThreadGroupSize,
     bool kUseVecBlocking,
@@ -247,7 +249,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -826,8 +828,8 @@ Tensor {{ embedding_cuda_op }}(
         AT_CUDA_CHECK(radix_sort_pairs(
             nullptr,
             temp_storage_bytes,
-            linear_indices.data_ptr<int64_t>(),
-            linear_indices_sorted.data_ptr<int64_t>(),
+            linear_indices.data_ptr<index_t>(),
+            linear_indices_sorted.data_ptr<index_t>(),
             {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(),
             {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(),
             linear_indices.numel(),
@@ -842,8 +844,8 @@ Tensor {{ embedding_cuda_op }}(
         AT_CUDA_CHECK(radix_sort_pairs(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            linear_indices.data_ptr<int64_t>(),
-            linear_indices_sorted.data_ptr<int64_t>(),
+            linear_indices.data_ptr<index_t>(),
+            linear_indices_sorted.data_ptr<index_t>(),
             {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(),
             {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(),
             linear_indices.numel(),
@@ -888,8 +890,8 @@ Tensor {{ embedding_cuda_op }}(
         AT_CUDA_CHECK(radix_sort_pairs(
             nullptr,
             temp_storage_bytes,
-            linear_indices.data_ptr<int64_t>(),
-            linear_indices_sorted.data_ptr<int64_t>(),
+            linear_indices.data_ptr<index_t>(),
+            linear_indices_sorted.data_ptr<index_t>(),
             indice_weights.data_ptr<at::acc_type<cache_t, true>>(),
             indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(),
             linear_indices.numel(),
@@ -904,8 +906,8 @@ Tensor {{ embedding_cuda_op }}(
         AT_CUDA_CHECK(radix_sort_pairs(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            linear_indices.data_ptr<int64_t>(),
-            linear_indices_sorted.data_ptr<int64_t>(),
+            linear_indices.data_ptr<index_t>(),
+            linear_indices_sorted.data_ptr<index_t>(),
             indice_weights.data_ptr<at::acc_type<cache_t, true>>(),
             indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(),
             linear_indices.numel(),
@@ -1174,6 +1176,7 @@ Tensor {{ embedding_cuda_op }}(
 
         ::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id);
     }
 }
 
-} // namespace fbgemm_gpu::rocm
\ No newline at end of file
+} // namespace fbgemm_gpu::rocm
+
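
Note (not part of the patch): the data_ptr<index_t>() calls in the hunks above only compile inside a block that binds index_t to the runtime dtype of the indices tensor. Below is a minimal sketch of that host-side dispatch using ATen's standard AT_DISPATCH_INDEX_TYPES macro, which covers exactly int32_t and int64_t, the same two types the Jinja loop ['int32_t', 'int64_t'] instantiates. The kernel and launcher names here are hypothetical stand-ins, not the generated code.

    // Sketch only: the dispatch pattern that makes data_ptr<index_t>() well-formed.
    #include <ATen/Dispatch.h>
    #include <ATen/core/Tensor.h>

    template <typename index_t>
    __global__ void example_backward_kernel(
        const index_t* sorted_linear_indices_run, int64_t n) {
      // ... kernel body elided ...
    }

    void launch_example(const at::Tensor& linear_indices) {
      AT_DISPATCH_INDEX_TYPES(
          linear_indices.scalar_type(), "example_backward", [&] {
            // Inside this lambda, index_t is int32_t or int64_t, selected at
            // runtime from linear_indices.scalar_type().
            example_backward_kernel<index_t><<<1, 64>>>(
                linear_indices.data_ptr<index_t>(), linear_indices.numel());
          });
    }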
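Also not part of the patch: each radix_sort_pairs call above appears twice, first with a nullptr workspace and then with temp_storage.data_ptr(). That is CUB's two-phase convention, which the templated index_t keys now flow through. A self-contained sketch using cub::DeviceRadixSort directly, under the assumption that radix_sort_pairs forwards to it; the function name here is illustrative:

    // Sketch of the two-phase CUB pattern: query temp_storage_bytes with a
    // null workspace, allocate, then run the actual sort.
    #include <cub/cub.cuh>
    #include <cuda_runtime.h>

    template <typename index_t, typename value_t>
    cudaError_t sort_pairs_example(
        const index_t* keys_in, index_t* keys_out,
        const value_t* vals_in, value_t* vals_out, int num_items) {
      size_t temp_storage_bytes = 0;
      // Phase 1: null workspace, so only the required byte count is computed.
      cub::DeviceRadixSort::SortPairs(
          nullptr, temp_storage_bytes,
          keys_in, keys_out, vals_in, vals_out, num_items);
      void* d_temp_storage = nullptr;
      cudaMalloc(&d_temp_storage, temp_storage_bytes);
      // Phase 2: the actual key/value sort, using the allocated workspace.
      cudaError_t err = cub::DeviceRadixSort::SortPairs(
          d_temp_storage, temp_storage_bytes,
          keys_in, keys_out, vals_in, vals_out, num_items);
      cudaFree(d_temp_storage);
      return err;
    }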