From 59ae33c212bf627809cd5fc0cfa8ca0333cf18a4 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Wed, 22 Jan 2025 10:33:36 -0800 Subject: [PATCH] Add support for `int32_t` indices in TBE training (2H/N) (#3539) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/626 - Add `int32_t` indices support for the TBE CPU kernels Reviewed By: sryap Differential Revision: D67784746 --- ...ing_backward_split_cpu_approx_template.cpp | 151 ++++++++++-------- .../embedding_backward_split_cpu_template.cpp | 99 +++++++----- ...dding_backward_split_host_cpu_template.cpp | 1 + .../forward/embedding_forward_split_cpu.cpp | 56 ++++--- .../utils/embedding_bounds_check_host_cpu.cpp | 14 +- 5 files changed, 195 insertions(+), 126 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp index 615b9288de..2069d13048 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp @@ -19,11 +19,17 @@ #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/utils/dispatch_macros.h" +#if FBGEMM_GPU_MEMCHECK +#define FBGEMM_MEM_CHECK_ONLY +#else +#define FBGEMM_MEM_CHECK_ONLY maybe_unused +#endif + using Tensor = at::Tensor; using namespace fbgemm_gpu; namespace { -template +template void split_embedding_backward_approx_cpu_kernel( Tensor grad_output, Tensor host_weights, @@ -44,8 +50,11 @@ void split_embedding_backward_approx_cpu_kernel( {{ args.split_cpu_kernel_args | join(", ") }}) { auto grad_output_data = grad_output.accessor(); auto host_weights_data = host_weights.accessor(); - const auto indices_data = indices.accessor(); - const auto offsets_data = offsets.accessor(); + + [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel"; + const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1); + const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1); + // If indice_weights are not defined, then this accessor won't be used auto indice_weights_data = indice_weights.defined() ? indice_weights.accessor, 1>() @@ -133,75 +142,84 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu( !indice_weights.defined() && static_cast(pooling_mode) == PoolingMode::SUM; if (use_fbgemm) { - auto grad_stride = grad_output.size(1); - const float* grad_output_data = grad_output.data_ptr(); - float* host_weights_data = host_weights.data_ptr(); - const int64_t* indices_data = indices.data_ptr(); - const int64_t* offsets_data = offsets.data_ptr(); - const auto hash_size_cumsum_data = hash_size_cumsum.accessor(); - float* momentum1_data = momentum1_host.data_ptr(); - - at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) { - int t_begin = tb_begin / B; - int t_end = (tb_end + B - 1) / B; -for (const auto t : c10::irange(t_begin,t_end)) { - auto D_begin = D_offsets_data[t]; - auto D = D_offsets_data[t + 1] - D_offsets_data[t]; - auto table_begin = weights_offsets_data[t]; - auto momentum_begin = momentum1_offsets_data[t]; - - int64_t hash_size; - int t_temp = t + 1; - do { - hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t]; - ++t_temp; - } while (hash_size == 0); - - int b_begin = (t == t_begin) ? tb_begin % B : 0; - int b_end = (t == t_end - 1 && tb_end % B != 0) ? 
tb_end % B : B; - - auto kernel = - fbgemm::GenerateRowWiseSparseAdaGradFused( - D, - /*prefetch=*/16, - /*use_offsets=*/true, - /*use_stochastic_round=*/true, - /*grad_stride=*/grad_stride); - auto offsets_begin_ptr = offsets_data + t * B + b_begin; - auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr; - bool success = kernel( - b_end - b_begin, - index_size, - hash_size, - reinterpret_cast(host_weights_data + table_begin), - reinterpret_cast( - grad_output_data + b_begin * grad_stride + D_begin), - reinterpret_cast(momentum1_data + momentum_begin), - indices_data + *offsets_begin_ptr, - offsets_begin_ptr, - eps, - // fbgemm follows caffe2 convention of negative learning rate - -learning_rate); - - if (!success) { - fbgemm_gpu::report_embedding_error( - t, B, b_begin, b_end, offsets_data, indices_data, hash_size); + AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] { + + auto grad_stride = grad_output.size(1); + const float* grad_output_data = grad_output.data_ptr(); + float* host_weights_data = host_weights.data_ptr(); + + const auto* indices_data = indices.data_ptr(); + const auto* offsets_data = offsets.data_ptr(); + + const auto hash_size_cumsum_data = hash_size_cumsum.accessor(); + float* momentum1_data = momentum1_host.data_ptr(); + + at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) { + int t_begin = tb_begin / B; + int t_end = (tb_end + B - 1) / B; + + for (const auto t : c10::irange(t_begin,t_end)) { + auto D_begin = D_offsets_data[t]; + auto D = D_offsets_data[t + 1] - D_offsets_data[t]; + auto table_begin = weights_offsets_data[t]; + auto momentum_begin = momentum1_offsets_data[t]; + + int64_t hash_size; + int t_temp = t + 1; + do { + hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t]; + ++t_temp; + } while (hash_size == 0); + + int b_begin = (t == t_begin) ? tb_begin % B : 0; + int b_end = (t == t_end - 1 && tb_end % B != 0) ? 
tb_end % B : B; + + auto kernel = + fbgemm::GenerateRowWiseSparseAdaGradFused( + D, + /*prefetch=*/16, + /*use_offsets=*/true, + /*use_stochastic_round=*/true, + /*grad_stride=*/grad_stride); + auto offsets_begin_ptr = offsets_data + t * B + b_begin; + auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr; + bool success = kernel( + b_end - b_begin, + index_size, + hash_size, + reinterpret_cast(host_weights_data + table_begin), + reinterpret_cast( + grad_output_data + b_begin * grad_stride + D_begin), + reinterpret_cast(momentum1_data + momentum_begin), + indices_data + *offsets_begin_ptr, + offsets_begin_ptr, + eps, + // fbgemm follows caffe2 convention of negative learning rate + -learning_rate); + + if (!success) { + fbgemm_gpu::report_embedding_error( + t, B, b_begin, b_end, offsets_data, indices_data, hash_size); + } } - } - }); // parallel_for + }); // parallel_for + }); // dispatch indices.scalar_type() + return; } // use_fbgemm {% endif %} - FBGEMM_DISPATCH_FLOAT_AND_HALF( - grad_output.scalar_type(), "split_embedding_backward_cpu", [&] { + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] { + + FBGEMM_DISPATCH_FLOAT_AND_HALF( + grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] { using grad_t = scalar_t; - FBGEMM_DISPATCH_FLOAT_AND_HALF( - host_weights.scalar_type(), - "split_embedding_backward_cpu_inner", - [&] { - split_embedding_backward_approx_cpu_kernel( + + FBGEMM_DISPATCH_FLOAT_AND_HALF( + host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] { + split_embedding_backward_approx_cpu_kernel( grad_output, host_weights, weights_offsets_data, @@ -220,7 +238,8 @@ for (const auto t : c10::irange(t_begin,t_end)) { {% endif %} {{ args.split_cpu_kernel_arg_constructors | join(", ") }}); }); // dispatch host_weights.scalar_type() - }); // dispatch grad_output.scalar_type() + }); // dispatch grad_output.scalar_type() + }); // dispatch indices.scalar_type() return; } diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp index 3f6095c955..8e1f9d7852 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp @@ -24,6 +24,12 @@ #include "fbgemm_gpu/utils/cpu_utils.h" #include "fbgemm_gpu/utils/ops_utils.h" +#if FBGEMM_GPU_MEMCHECK +#define FBGEMM_MEM_CHECK_ONLY +#else +#define FBGEMM_MEM_CHECK_ONLY maybe_unused +#endif + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -40,7 +46,7 @@ struct half2float16 { } // namespace internal namespace { -template +template void split_embedding_backward_exact_cpu_kernel( Tensor grad_output, Tensor host_weights, @@ -94,8 +100,8 @@ for (const auto t : c10::irange(num_tables)) { ::internal::csr2csc( cscs[t], B, - MAKE_TA_WITH_NAME(func_name, offsets, int64_t, 1), - MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1), + MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1), + MAKE_TA_WITH_NAME(func_name, indices, index_t, 1), MAKE_TA_WITH_NAME(func_name, indice_weights, weight_t, 1), pooling_mode, table_to_feature_offset + t, @@ -196,19 +202,21 @@ for (const auto t : c10::irange(num_tables)) { // TODO: to parallelize, we should easily identify segments belong to // the same column. 
at::acc_type grad_buffer[D]; -for (const auto c : c10::irange(num_non_zero_columns)) { + for (const auto c : c10::irange(num_non_zero_columns)) { int64_t idx = col_segment_indices[c]; if (c == 0 || col_segment_indices[c - 1] != idx) { memset(grad_buffer, 0, D * sizeof(at::acc_type)); } [[maybe_unused]] const int64_t embedding_begin = table_begin + idx * D; + for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) { int D_offset = D_begin; if (is_shared_table) { D_offset += cscs[t].column_segment_ids[r] * D; } int b = cscs[t].row_indices[r]; -for (const auto d : c10::irange(D)) { + + for (const auto d : c10::irange(D)) { if (cscs[t].weights != nullptr) { grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] * cscs[t].weights[r]; @@ -225,7 +233,7 @@ for (const auto d : c10::irange(D)) { } // for each table } -template +template void split_embedding_backward_exact_cpu_dense_kernel( Tensor grad, Tensor grad_output, @@ -242,8 +250,10 @@ void split_embedding_backward_exact_cpu_dense_kernel( auto grad_output_data = grad_output.accessor(); - const auto indices_data = indices.accessor(); - const auto offsets_data = offsets.accessor(); + [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_exact_cpu_dense_kernel"; + + const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1); + const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1); const auto indice_weights_data = indice_weights.defined() ? // If indice_weights are not defined, then this accessor won't be @@ -349,34 +359,41 @@ for (const auto d : c10::irange(D)) { grad_output = grad_output.contiguous(); - - FBGEMM_DISPATCH_FLOAT_AND_HALF( + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "split_embedding_backward_exact_cpu_kernel_1", [&] { + + FBGEMM_DISPATCH_FLOAT_AND_HALF( grad_output.scalar_type(), - "split_embedding_backward_exact_cpu_outer", [&]() { - using grad_t = scalar_t; + "split_embedding_backward_exact_cpu_kernel_2", [&] { + using grad_t = scalar_t; + FBGEMM_DISPATCH_FLOAT_AND_HALF( - host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] { - split_embedding_backward_exact_cpu_kernel( - grad_output, - host_weights, - weights_offsets_data, - D_offsets_data, - hash_size_cumsum, - indices, - offsets, - pooling_mode, - indice_weights, - num_tables, - B, - table_to_feature_offset, - {% if "momentum1_offsets" in args.split_function_arg_names %} - momentum1_offsets_data, - {% endif %} - {% if "momentum2_offsets" in args.split_function_arg_names %} - momentum2_offsets_data, - {% endif %} - {{ args.split_cpu_kernel_arg_constructors | join(", ") }}); - }); + host_weights.scalar_type(), + "split_embedding_backward_exact_cpu_kernel_3", [&] { + + split_embedding_backward_exact_cpu_kernel( + grad_output, + host_weights, + weights_offsets_data, + D_offsets_data, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + num_tables, + B, + table_to_feature_offset, + {% if "momentum1_offsets" in args.split_function_arg_names %} + momentum1_offsets_data, + {% endif %} + {% if "momentum2_offsets" in args.split_function_arg_names %} + momentum2_offsets_data, + {% endif %} + {{ args.split_cpu_kernel_arg_constructors | join(", ") }}); + }); + }); }); return; @@ -385,10 +402,15 @@ for (const auto d : c10::irange(D)) { // When input is dense enough, avoid sorting and just treat as dense. 
auto grad = zeros_like(host_weights, grad_output.dtype()); - FBGEMM_DISPATCH_FLOAT_AND_HALF( - grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] { + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "split_embedding_backward_exact_cpu_dense_kernel", [&] { - split_embedding_backward_exact_cpu_dense_kernel( + FBGEMM_DISPATCH_FLOAT_AND_HALF( + grad_output.scalar_type(), + "split_embedding_backward_exact_cpu", [&] { + + split_embedding_backward_exact_cpu_dense_kernel( grad, grad_output, weights_offsets_data, @@ -400,7 +422,8 @@ for (const auto d : c10::irange(D)) { num_tables, B, table_to_feature_offset); - }); // dispatch host_weights.scalar_type() + }); + }); return grad; {% endif %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp index fc7d8a58f3..3dde059b37 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp @@ -158,6 +158,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function< stochastic_rounding, {{ args.split_function_arg_names | join(", ") }}, output_dtype); + static auto op2 = torch::Dispatcher::singleton() .findSchemaOrThrow("fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "") diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp index 5117583415..104ed9590f 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp @@ -26,6 +26,12 @@ #include #include +#if FBGEMM_GPU_MEMCHECK +#define FBGEMM_MEM_CHECK_ONLY +#else +#define FBGEMM_MEM_CHECK_ONLY maybe_unused +#endif + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -287,7 +293,7 @@ Tensor split_embedding_codegen_forward_cpu_meta( return output; } -template +template void split_embedding_grad_indice_weights_cpu_kernel( Tensor grad_output, Tensor weights, @@ -305,8 +311,11 @@ void split_embedding_grad_indice_weights_cpu_kernel( const auto D_offsets_data = D_offsets.accessor(); const auto weights_offsets_data = weights_offsets.accessor(); - const auto offsets_data = offsets.accessor(); - const auto indices_data = indices.accessor(); + + [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = + "split_embedding_grad_indice_weights_cpu_kernel"; + const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1); + const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1); const auto weights_data = weights.accessor(); const auto grad_output_data = grad_output.accessor(); @@ -352,25 +361,34 @@ Tensor split_embedding_codegen_grad_indice_weights_cpu( indices, indices.options().dtype( at::toAccumulateType(grad_output.scalar_type(), true))); - FBGEMM_DISPATCH_FLOAT_AND_HALF( - grad_output.scalar_type(), - "split_embedding_grad_indice_weights_cpu_outer", + + AT_DISPATCH_INDEX_TYPES( + indices.scalar_type(), + "split_embedding_grad_indice_weights_cpu_kernel_1", [&] { - using grad_t = scalar_t; FBGEMM_DISPATCH_FLOAT_AND_HALF( - weights.scalar_type(), - "split_embedding_grad_indice_weights_cpu", + grad_output.scalar_type(), + "split_embedding_grad_indice_weights_cpu_kernel_2", [&] { - using weights_t = scalar_t; - split_embedding_grad_indice_weights_cpu_kernel( - grad_output, - weights, - weights_offsets, 
- D_offsets, - indices, - offsets, - feature_requires_grad, - grad_indice_weights); + using grad_t = scalar_t; + FBGEMM_DISPATCH_FLOAT_AND_HALF( + weights.scalar_type(), + "split_embedding_grad_indice_weights_cpu_kernel_3", + [&] { + using weights_t = scalar_t; + split_embedding_grad_indice_weights_cpu_kernel< + index_t, + weights_t, + grad_t>( + grad_output, + weights, + weights_offsets, + D_offsets, + indices, + offsets, + feature_requires_grad, + grad_indice_weights); + }); }); }); diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 4a6bf5cc93..2316870f96 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -13,8 +13,15 @@ #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/utils/dispatch_macros.h" #include "fbgemm_gpu/utils/ops_utils.h" +#include "fbgemm_gpu/utils/tensor_accessor.h" #include "fbgemm_gpu/utils/tensor_utils.h" +#if FBGEMM_GPU_MEMCHECK +#define FBGEMM_MEM_CHECK_ONLY +#else +#define FBGEMM_MEM_CHECK_ONLY maybe_unused +#endif + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -75,9 +82,10 @@ void bounds_check_indices_cpu( auto warning_acc = warning.data_ptr(); AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] { - auto offsets_acc = offsets.accessor(); - auto indices_acc = indices.accessor(); - auto num_indices = indices.numel(); + [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "bounds_check_indices_cpu"; + auto indices_acc = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1); + auto offsets_acc = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1); + const auto num_indices = indices.numel(); if (vbe) { TORCH_CHECK(max_B >= 0);
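
The recurring pattern throughout this patch is to template each CPU kernel on an index_t type and let AT_DISPATCH_INDEX_TYPES select the instantiation from indices.scalar_type(), so a single kernel body serves both int32_t and int64_t index tensors. Below is a minimal, self-contained sketch of that pattern, not FBGEMM code: the sum_rows_cpu* names, the float-only weight table, and the trivial pooling loop are illustrative assumptions; only the dispatch macro and the accessor API are the real ATen facilities used by the patch.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

namespace {

// Kernel templated on the index type, mirroring how the TBE CPU kernels in
// this patch take index_t as a template parameter.
template <typename index_t>
void sum_rows_cpu_kernel(
    const at::Tensor& weights, // [num_rows, D] float table (assumed float here)
    const at::Tensor& indices, // [N] row ids, int32 or int64
    const at::Tensor& offsets, // [B + 1] bag boundaries, same dtype as indices
    at::Tensor& output) {      // [B, D] pooled output
  const auto weights_acc = weights.accessor<float, 2>();
  const auto indices_acc = indices.accessor<index_t, 1>();
  const auto offsets_acc = offsets.accessor<index_t, 1>();
  auto output_acc = output.accessor<float, 2>();
  const int64_t B = offsets.numel() - 1;
  const int64_t D = weights.size(1);
  for (int64_t b = 0; b < B; ++b) {
    for (index_t i = offsets_acc[b]; i < offsets_acc[b + 1]; ++i) {
      const auto row = indices_acc[i];
      for (int64_t d = 0; d < D; ++d) {
        output_acc[b][d] += weights_acc[row][d];
      }
    }
  }
}

} // namespace

at::Tensor sum_rows_cpu(
    const at::Tensor& weights,
    const at::Tensor& indices,
    const at::Tensor& offsets) {
  auto output =
      at::zeros({offsets.numel() - 1, weights.size(1)}, weights.options());
  // index_t is defined inside the lambda as int32_t or int64_t depending on
  // indices.scalar_type(); this is the same wrapping the patch applies around
  // the forward, backward, and bounds-check kernels.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_rows_cpu", [&] {
    sum_rows_cpu_kernel<index_t>(weights, indices, offsets, output);
  });
  return output;
}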
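
The FBGEMM_MEM_CHECK_ONLY macro introduced at the top of each file keeps the func_name variable warning-free in both build modes: when FBGEMM_GPU_MEMCHECK is off, func_name is consumed only by the memcheck-enabled accessor macros, so it is annotated [[maybe_unused]]; when memcheck is on, the macro expands to an empty (still valid) attribute list. A short sketch follows; the #if block is copied from the patch, while example_kernel is a hypothetical function for illustration, and MAKE_TA_WITH_NAME is FBGEMM's existing named-accessor macro, shown only in a comment rather than redefined.

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

void example_kernel() { // hypothetical kernel, for illustration only
  // [[FBGEMM_MEM_CHECK_ONLY]] expands to [[maybe_unused]] in regular builds
  // and to [[]] (an empty but well-formed attribute list) in memcheck builds,
  // so -Wunused-variable is avoided without any #ifdef at the use site.
  [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "example_kernel";
  // In the real kernels, func_name is passed to MAKE_TA_WITH_NAME(...) so a
  // memcheck-enabled accessor can report which kernel made an out-of-bounds
  // access.
}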