From 555d40f2d75be4ee2a12fab5a43f8f26ce4891f9 Mon Sep 17 00:00:00 2001
From: Supadchaya Puangpontip
Date: Wed, 25 Sep 2024 14:38:44 -0700
Subject: [PATCH] Enable VBE support on CPU

Summary:
VBE on CPU was previously enabled in lookup_{{ optimizer }}.py. To support MTIA
ops, the VBE handling has to happen after
torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2
is called. This diff follows the same implementation but moves it into C++ so
that the call goes through the same PT2 pipeline, i.e.,
lookup -> VBE autograd -> cpu wrapper (*VBE is done here*) -> cpu kernel.
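The reshaping that the CPU wrapper now performs can be summarized with the
following minimal Python sketch (names and shapes are illustrative, not the
actual codegen API): the existing CPU kernel output is treated as a dense
[max_B, total_D] pooled tensor, and the wrapper scatters it into the flat VBE
output using the per-feature/per-rank offsets. reshape_vbe_output() in
pt2_autograd_utils.cpp performs the inverse gather on grad_output in backward.

    import torch

    def scatter_dense_to_vbe(
        dense_output,                 # [max_B, total_D] from the existing CPU kernel
        D_offsets,                    # [T + 1] column offsets per feature
        B_offsets_rank_per_feature,   # [T, R + 1] cumulative batch sizes per (feature, rank)
        output_offsets_feature_rank,  # [R * T + 1] offsets into the flat VBE output
        vbe_output_size,
    ):
        T = D_offsets.numel() - 1
        R = B_offsets_rank_per_feature.size(1) - 1
        out = dense_output.new_empty(vbe_output_size)
        for r in range(R):
            D_offset = 0
            for t in range(T):
                o_begin = int(output_offsets_feature_rank[r * T + t])
                o_end = int(output_offsets_feature_rank[r * T + t + 1])
                D = int(D_offsets[t + 1] - D_offsets[t])
                b_begin = int(B_offsets_rank_per_feature[t][r])
                b_end = int(B_offsets_rank_per_feature[t][r + 1])
                # each (rank, feature) block is b rows of D columns, flattened
                assert o_end - o_begin == (b_end - b_begin) * D
                out[o_begin:o_end] = dense_output[b_begin:b_end, D_offset:D_offset + D].flatten()
                D_offset += D
        return out
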
Differential Revision: D63410944
---
 .../genscript/generate_forward_split.py       |   1 +
 ...dding_split_host_pt2_autograd_template.cpp |  31 +++--
 ...ng_split_host_pt2_cpu_wrapper_template.cpp | 127 +++++++++++++-----
 .../training/pt2/pt2_autograd_utils.cpp       |  68 ++++++++++
 .../fbgemm_gpu/utils/pt2_autograd_utils.h     |  31 +++++
 .../tbe/training/backward_adagrad_test.py     |  27 ++++
 .../test/tbe/training/backward_sgd_test.py    |  60 +++++++++
 .../test/tbe/training/failures_dict_fast.json |   6 +
 8 files changed, 311 insertions(+), 40 deletions(-)
 create mode 100644 fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
 create mode 100644 fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h

diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py
index 285cf9a559..894ce104ce 100644
--- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None:
             f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp",
             has_cpu_support=True,
             is_forward=True,
+            has_vbe_support=True,
         )
 
         # Generate PT2 forward wrapper (CUDA)
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index b623b92d03..b09b4e1f47 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -37,6 +37,9 @@
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
 #include "fbgemm_gpu/config/feature_gates.h"
+{%- if has_vbe_support %}
+#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
+{%- endif %}
 
 using Tensor = at::Tensor;
 
@@ -236,9 +239,9 @@ enum SSDTensor {
     const Tensor& /*prev_iter_dev*/,
     {%- endif %}
     {%- if "iter" not in args_pt2.split_function_arg_names %}
-    const int64_t iter,
+    const int64_t /*iter*/,
     {%- endif %}
-    const double gwd_lower_bound,
+    const double /*gwd_lower_bound*/,
     {%- endif %} {# /* if is_gwd */ #}
     {%- for arg_type in args_pt2.split_function_args %}
     {{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %}
@@ -617,11 +620,15 @@ class {{ autograd_func }} :
                 const c10::SymInt,
                 const int64_t,
                 const c10::SymInt)>();
-
-    auto [
-        vbe_row_output_offsets,
-        vbe_b_t_map
-    ] = generate_vbe_metadata_op.call(
+    Tensor vbe_row_output_offsets, vbe_b_t_map;
+    if (weights_host.numel() > 0){
+      // generate_vbe_metadata_op is not implemented for CPU
+      // TODO: implement CPU version of generate_vbe_metadata_op and remove this branch
+      vbe_row_output_offsets = vbe_output_offsets_feature_rank_;
+      vbe_b_t_map = vbe_B_offsets_rank_per_feature_;
+    }
+    else{
+      std::tie(vbe_row_output_offsets, vbe_b_t_map) = generate_vbe_metadata_op.call(
         B_offsets_,
         vbe_B_offsets_rank_per_feature_,
         vbe_output_offsets_feature_rank_,
@@ -639,6 +646,7 @@ class {{ autograd_func }} :
         info_B_num_bits,
         /*total_B=*/offsets.sym_size(0) - 1
       );
+    }
     {%- endif %} // vbe
 
     {%- if is_gwd %}
@@ -850,6 +858,11 @@ static torch::autograd::variable_list backward(
     //   {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
     weights_dev = weights_dev.flatten();
     {%- endif %}
+    {%- if vbe %}
+    if (weights_host.numel() > 1){
+      grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
+    }
+    {%- endif %}
 
     {%- set grad_indice_weights_op =
         "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -883,7 +896,7 @@ static torch::autograd::variable_list backward(
             {%- else %}
             const Tensor& /*feature_requires_grad*/
             {%- endif %}
-          )>();
+        )>();
 
   const auto grad_indice_weights = !indice_weights.defined()
       ? Variable() :
@@ -1014,7 +1027,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
   {%- if not ssd %}
  {%- if has_vbe_support %}
   // has vbe support and on gpu
-  if (B_offsets.has_value() && !(weights[0].numel() > 0)) {
+  if (B_offsets.has_value()) {
     {%- if has_global_weight_decay_support %}
     // vbe and has gwd support
     if (apply_global_weight_decay && weight_decay > 0) {
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
index c74355207f..5b2b066fee 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
@@ -30,9 +30,12 @@
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
+{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
+{%- set vdesc = "_vbe" if vbe else "" %}
+
 {%- if is_forward %}
 {#-/* PT2 wrapper function for backward grad_indice_weights CPU */#}
-Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
+Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& grad_output,
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
@@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
     const Tensor& indices,
     const Tensor& offsets,
     const Tensor& /*lxu_cache_locations*/,
-    const Tensor& feature_requires_grad) {
+    {%- if vbe %}
+    const Tensor& feature_requires_grad,
+    const Tensor& vbe_row_output_offsets,
+    const Tensor& vbe_b_t_map,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask_int64
+    {%- else %}
+    const Tensor& feature_requires_grad
+    {%- endif %}
+) {
   static auto op =
       torch::Dispatcher::singleton()
           .findSchemaOrThrow(
@@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
 
 {% if is_forward %}
 {#-/* PT2 wrapper function for forward CPU */#}
-Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper(
+Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
     const Tensor& /*uvm_weights*/,
@@ -84,30 +96,77 @@
     const Tensor& indice_weights,
     const Tensor& /*lxu_cache_locations*/,
     const Tensor& /*uvm_cache_stats*/,
+    {%- if vbe %}
+    const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
+    const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
+    const c10::SymInt vbe_output_size,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask_int64,
+    {%- endif %}
     const bool /*is_experimental = false*/,
     const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
-  static auto op =
-      torch::Dispatcher::singleton()
-          .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
-          .typed<decltype(split_embedding_codegen_forward_cpu)>();
+  static auto op =
+    torch::Dispatcher::singleton()
+      .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
+      .typed<decltype(split_embedding_codegen_forward_cpu)>();
+  {%- if vbe %}
+  // TODO: remove this after vbe is implemented for CPU kernel
+  Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
+  Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
+  const auto output = op.call(
+      host_weights,
+      weights_offsets,
+      D_offsets,
+      total_D,
+      hash_size_cumsum,
+      indices,
+      offsets,
+      pooling_mode,
+      indice_weights,
+      output_dtype);
+  auto options = at::TensorOptions()
+      .dtype(output.options().dtype())
+      .device(host_weights.options().device());
+  const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__);
+  Tensor output_new = at::empty({vbe_output_size_}, options);
+  const int32_t T = D_offsets.numel() - 1;
+  const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1;
 
-  return op.call(
-      host_weights,
-      weights_offsets,
-      D_offsets,
-      total_D,
-      hash_size_cumsum,
-      indices,
-      offsets,
-      pooling_mode,
-      indice_weights,
-      output_dtype);
-}
+  for (int32_t r = 0; r < R; r++){
+    auto D_offset = 0;
+    for (int32_t t = 0; t < T; t++){
+      const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item<int32_t>();
+      const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item<int32_t>();
+      const int32_t D = D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
+      const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item<int32_t>();
+      const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item<int32_t>();
+
+      TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D));
+      auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten();
+      output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values);
+      D_offset += D;
+    }
+  }
+  return output_new;
+  {%- else %}
+  return op.call(
+      host_weights,
+      weights_offsets,
+      D_offsets,
+      total_D,
+      hash_size_cumsum,
+      indices,
+      offsets,
+      pooling_mode,
+      indice_weights,
+      output_dtype);
+  {%- endif %}
+  }
 
 
 {% else %}
 {#-/* PT2 wrapper function for backward CPU */#}
-Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper(
+Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& grad_output,
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
@@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap(
     const int64_t /*BT_block_size*/,
     const int64_t /*max_segment_length_per_warp*/,
     const bool stochastic_rounding,
-    const int64_t /*info_B_num_bits*/,
-    const int64_t /*info_B_mask_int64*/,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask_int64,
+    {%- if vbe %}
+    const Tensor& B_offsets,
+    const Tensor& vbe_row_output_offsets,
+    const Tensor& vbe_b_t_map,
+    {%- endif %}
     const bool /*use_uniq_cache_locations*/,
     const bool /*use_homogeneous_placements*/,
     {{ args_pt2.split_function_args | join(", ") }}
@@ -194,29 +258,30 @@ namespace {
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   {%- if is_forward %}
   DISPATCH_TO_CPU(
-      "split_embedding_codegen_grad_indice_weights_pt2_wrapper",
-      split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper);
+      "split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper",
+      split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper);
   {%- endif %}
 
   {%- for weighted in [True, False] %}
   {%- set wdesc = "weighted" if weighted else "unweighted" %}
   {%- if is_forward %}
-  {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format(
-      wdesc
+  {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format(
+      wdesc, vdesc
       ) %}
   DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper",
       {{ embedding_codegen_forward_op }}_cpu_wrapper);
   {%- else %}
-  {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format(
-      optimizer, wdesc
+  {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format(
+      optimizer, wdesc, vdesc
      ) %}
   DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper",
       {{ embedding_codegen_backward_op }}_cpu_wrapper);
   {%- endif %}
   {%- endfor %} {#-/*for weighted*/#}
 }
-
 } // namespace
+{%- endfor %} {#-/* for vbe in [True, False] */#}
+
 {% endif %} // if has_cpu_support
 
 // clang-format on
diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
new file mode 100644
index 0000000000..d79beb3e2b
--- /dev/null
+++ b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/TypeDefault.h>
+// #include
+// #include
+// #include "fbgemm_gpu/embedding_common.h"
+// #include "fbgemm_gpu/utils/dispatch_macros.h"
+// #include "fbgemm_gpu/utils/ops_utils.h"
+// #include "fbgemm_gpu/utils/tensor_utils.h"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper Functions
+////////////////////////////////////////////////////////////////////////////////
+
+Tensor reshape_vbe_output(
+    const Tensor& grad_output,
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& D_offsets) {
+  /* FOR CPU VBE to use the same backend */
+  const auto T = D_offsets.numel() - 1;
+  int32_t max_B = 0;
+  int32_t total_D = 0;
+  // find max_B, total_D to create output [max_B, total_D]
+  for (int32_t t = 0; t < T; t++) {
+    auto b = B_offsets[t + 1].item<int32_t>() - B_offsets[t].item<int32_t>();
+    max_B = std::max(max_B, b);
+    total_D += D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
+  }
+  auto grad_output_ = at::empty({max_B, total_D}, grad_output.options());
+  // for each feature
+  auto offset = 0;
+
+  const int32_t R = B_offsets_rank_per_feature.size(1) - 1;
+  for (int32_t r = 0; r < R; r++) {
+    auto D_offset = 0;
+    for (int32_t t = 0; t < T; t++) {
+      const int32_t b_begin = B_offsets_rank_per_feature[t][r].item<int32_t>();
+      const int32_t b_end =
+          B_offsets_rank_per_feature[t][r + 1].item<int32_t>();
+      const int32_t D =
+          D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
+      const int32_t b = b_end - b_begin;
+      const int32_t num_elm = b * D;
+      auto values = grad_output.slice(0, offset, offset + num_elm);
+      values = values.reshape({b, D});
+      grad_output_.index_put_(
+          {at::indexing::Slice(b_begin, b_end),
+           at::indexing::Slice(D_offset, D_offset + D)},
+          values);
+      D_offset += D;
+      offset += num_elm;
+    }
+  }
+  return grad_output_;
+}
+} // namespace fbgemm_gpu
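Illustrative aside (not part of the applied patch): the loops above and in the
forward CPU wrapper assume the VBE metadata layout sketched below. Example
values in Python for T = 2 features, R = 2 ranks, per-rank batch sizes
[[2, 1], [3, 2]] and embedding dims D = [4, 8]; names mirror the C++ arguments,
and these are the tensors the scatter sketch in the commit message expects.

    import torch

    D_offsets = torch.tensor([0, 4, 12])                    # per-feature column offsets
    B_offsets_rank_per_feature = torch.tensor([[0, 2, 3],   # feature 0: 2 rows for rank 0, 1 for rank 1
                                               [0, 3, 5]])  # feature 1: 3 rows for rank 0, 2 for rank 1
    # The flat VBE output is laid out rank-major, feature-minor; each (r, t) block
    # holds b * D values: (r0,t0)=2*4, (r0,t1)=3*8, (r1,t0)=1*4, (r1,t1)=2*8.
    output_offsets_feature_rank = torch.tensor([0, 8, 32, 36, 52])
    vbe_output_size = 52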
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h
new file mode 100644
index 0000000000..3aff58c9af
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/TypeDefault.h>
+// #include
+// #include
+// #include "fbgemm_gpu/embedding_common.h"
+// #include "fbgemm_gpu/utils/dispatch_macros.h"
+// #include "fbgemm_gpu/utils/ops_utils.h"
+// #include "fbgemm_gpu/utils/tensor_utils.h"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper Functions
+////////////////////////////////////////////////////////////////////////////////
+
+Tensor reshape_vbe_output(
+    const Tensor& grad_output,
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& D_offsets);
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py
index e8a8017d78..2181c5087c 100644
--- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py
@@ -30,6 +30,11 @@
 test_st: Dict[str, Any] = common_strategy.copy()
 test_st["D"] = st.integers(min_value=2, max_value=128)
 
+test_st_cpu = test_st.copy()
+test_st_cpu["use_cpu"] = st.just(True)
+test_st_cpu["output_dtype"] = st.sampled_from([SparseType.FP32, SparseType.FP16])
+test_st_cpu["row_wise"] = st.just(True)
+
 
 @optests.generate_opcheck_tests(fast=True)
 class BackwardAdagradTest(unittest.TestCase):
@@ -100,6 +105,28 @@ def test_backward_adagrad_fp32_pmSUM( # noqa C901
             **kwargs,
         )
 
+    @given(
+        compile=st.booleans(),
+        pooling_mode=st.sampled_from([PoolingMode.SUM, PoolingMode.MEAN]),
+        **test_st_cpu,
+    )
+    @settings(**common_settings)
+    def test_backward_adagrad_fp32_cpu(  # noqa C901
+        self,
+        pooling_mode: PoolingMode,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Test VBE support for CPU on rowwise adagrad
+        """
+        kwargs = adjust_mixed_B_st(kwargs)
+        execute_backward_adagrad(
+            weights_precision=SparseType.FP32,
+            pooling_mode=pooling_mode,
+            mixed_B=True,
+            **kwargs,
+        )
+
     @given(
         mixed_B=st.booleans(),
         compile=st.booleans(),
diff --git a/fbgemm_gpu/test/tbe/training/backward_sgd_test.py b/fbgemm_gpu/test/tbe/training/backward_sgd_test.py
index dc91c75805..603ca45d52 100644
--- a/fbgemm_gpu/test/tbe/training/backward_sgd_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_sgd_test.py
@@ -391,6 +391,66 @@ def test_backward_sgd( # noqa C901
             SparseType.FP32,  # output_dtype
         )
 
+    @given(
+        T=st.integers(min_value=1, max_value=5),
+        D=st.integers(min_value=2, max_value=256),
+        B=st.integers(min_value=1, max_value=128),
+        log_E=st.integers(min_value=3, max_value=5),
+        L=st.integers(min_value=0, max_value=20),
+        weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]),
+        weighted=st.booleans(),
+        mixed=st.booleans(),
+        use_cache=st.booleans(),
+        cache_algorithm=st.sampled_from(CacheAlgorithm),
+        long_segments=st.booleans(),
+        pooling_mode=st.sampled_from(
+            [
+                PoolingMode.SUM,
+                PoolingMode.MEAN,
+            ]
+        ),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    def test_backward_sgd_vbe_cpu(  # noqa C901
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weights_precision: SparseType,
+        weighted: bool,
+        mixed: bool,
+        use_cache: bool,
+        cache_algorithm: CacheAlgorithm,
+        long_segments: bool,
+        pooling_mode: PoolingMode,
+    ) -> None:
+        use_cpu = True
+        mixed_B = True
+        self.execute_backward_sgd_(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weights_precision,
+            weighted,
+            mixed,
+            mixed_B,
+            use_cache,
+            cache_algorithm,
+            long_segments,
+            pooling_mode,
+            use_cpu,
+            SparseType.FP32,  # output_dtype
+        )
+
     @given(
         D=st.integers(min_value=2, max_value=10),
         # 128 * 1024 is to exercise a case num_ctas_for_run needs to be capped
diff --git a/fbgemm_gpu/test/tbe/training/failures_dict_fast.json b/fbgemm_gpu/test/tbe/training/failures_dict_fast.json
index e419a1a34c..f2986b7bbb 100644
--- a/fbgemm_gpu/test/tbe/training/failures_dict_fast.json
+++ b/fbgemm_gpu/test/tbe/training/failures_dict_fast.json
@@ -371,6 +371,12 @@
         "fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {},
         "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {},
         "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {},
+        "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_pt2": {
+            "BackwardAdagradTest.test_faketensor__test_backward_adagrad_fp32_cpu": {
+                "comment": "",
+                "status": "xfail"
+            }
+        },
         "fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function": {},
         "fbgemm::split_embedding_codegen_lookup_rowwise_weighted_adagrad_function": {},
         "fbgemm::split_embedding_codegen_lookup_sgd_function": {},