diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake index dd23dca852..99d5e5b0c5 100644 --- a/fbgemm_gpu/FbgemmGpu.cmake +++ b/fbgemm_gpu/FbgemmGpu.cmake @@ -263,10 +263,10 @@ list(APPEND gen_gpu_host_source_files foreach(optimizer ${ALL_OPTIMIZERS}) list(APPEND gen_cpu_source_files "gen_embedding_backward_split_${optimizer}_cpu.cpp" - "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp") + "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp" + "gen_embedding_split_${optimizer}_pt2_autograd.cpp") list(APPEND gen_gpu_host_source_files "gen_embedding_backward_split_${optimizer}.cpp" - "gen_embedding_split_${optimizer}_pt2_autograd.cpp" "gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp") endforeach() @@ -454,6 +454,7 @@ set(fbgemm_gpu_sources_static_cpu codegen/training/forward/embedding_forward_split_cpu.cpp codegen/inference/embedding_forward_quantized_host_cpu.cpp codegen/training/backward/embedding_backward_dense_host_cpu.cpp + codegen/training/pt2/pt2_autograd_utils.cpp codegen/utils/embedding_bounds_check_host_cpu.cpp src/config/feature_gates.cpp src/memory_utils/memory_utils.cpp diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index 285cf9a559..894ce104ce 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None: f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp", has_cpu_support=True, is_forward=True, + has_vbe_support=True, ) # Generate PT2 forward wrapper (CUDA) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index b623b92d03..3666de5b9a 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -35,8 +35,12 @@ #include "fbgemm_gpu/utils/ops_utils.h" #include #include "fbgemm_gpu/utils/dispatch_macros.h" -#include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/split_embeddings_utils.h" #include "fbgemm_gpu/config/feature_gates.h" +{%- if has_vbe_support %} +#include "fbgemm_gpu/utils/pt2_autograd_utils.h" +{%- endif %} using Tensor = at::Tensor; @@ -236,9 +240,9 @@ enum SSDTensor { const Tensor& /*prev_iter_dev*/, {%- endif %} {%- if "iter" not in args_pt2.split_function_arg_names %} - const int64_t iter, + const int64_t /*iter*/, {%- endif %} - const double gwd_lower_bound, + const double /*gwd_lower_bound*/, {%- endif %} {# /* if is_gwd */ #} {%- for arg_type in args_pt2.split_function_args %} {{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %} @@ -617,7 +621,6 @@ class {{ autograd_func }} : const c10::SymInt, const int64_t, const c10::SymInt)>(); - auto [ vbe_row_output_offsets, vbe_b_t_map @@ -850,6 +853,11 @@ static torch::autograd::variable_list backward( // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda) weights_dev = weights_dev.flatten(); {%- endif %} + {%- if vbe %} + if (weights_host.numel() > 1){ + grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets); + } + {%- endif %} {%- set grad_indice_weights_op = "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc) @@ -883,7 +891,7 @@ static torch::autograd::variable_list backward( {%- else %} const Tensor& /*feature_requires_grad*/ {%- endif %} - )>(); + )>(); const auto grad_indice_weights = !indice_weights.defined() ? Variable() : @@ -1014,7 +1022,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2( {%- if not ssd %} {%- if has_vbe_support %} // has vbe support and on gpu - if (B_offsets.has_value() && !(weights[0].numel() > 0)) { + if (B_offsets.has_value()) { {%- if has_global_weight_decay_support %} // vbe and has gwd support if (apply_global_weight_decay && weight_decay > 0) { diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp index c74355207f..5b2b066fee 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp @@ -30,9 +30,12 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; +{%- for vbe in ([True, False] if has_vbe_support else [False]) %} +{%- set vdesc = "_vbe" if vbe else "" %} + {%- if is_forward %} {#-/* PT2 wrapper function for backward grad_indice_weights CPU */#} -Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( +Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( const Tensor& indices, const Tensor& offsets, const Tensor& /*lxu_cache_locations*/, - const Tensor& feature_requires_grad) { + {%- if vbe %} + const Tensor& feature_requires_grad, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64 + {%- else %} + const Tensor& feature_requires_grad + {%- endif %} +) { static auto op = torch::Dispatcher::singleton() .findSchemaOrThrow( @@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper( {% if is_forward %} {#-/* PT2 wrapper function for forward CPU */#} -Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( +Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper( const Tensor& host_weights, const Tensor& /*dev_weights*/, const Tensor& /*uvm_weights*/, @@ -84,30 +96,77 @@ Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper( const Tensor& indice_weights, const Tensor& /*lxu_cache_locations*/, const Tensor& /*uvm_cache_stats*/, + {%- if vbe %} + const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/ + const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/ + const c10::SymInt vbe_output_size, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64, + {%- endif %} const bool /*is_experimental = false*/, const int64_t output_dtype = static_cast(SparseType::FP32)) { - static auto op = - torch::Dispatcher::singleton() - .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") - .typed(); + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") + .typed(); + {%- if vbe %} + // TODO: remove this after vbe is implemented for CPU kernel + Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map; + Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets; + const auto output = op.call( + host_weights, + weights_offsets, + D_offsets, + total_D, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + output_dtype); + auto options = at::TensorOptions() + .dtype(output.options().dtype()) + .device(host_weights.options().device()); + const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__); + Tensor output_new = at::empty({vbe_output_size_}, options); + const int32_t T = D_offsets.numel() - 1; + const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1; - return op.call( - host_weights, - weights_offsets, - D_offsets, - total_D, - hash_size_cumsum, - indices, - offsets, - pooling_mode, - indice_weights, - output_dtype); -} + for (int32_t r = 0; r < R; r++){ + auto D_offset = 0; + for (int32_t t = 0; t < T; t++){ + const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item(); + const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item(); + const int32_t D = D_offsets[t + 1].item() - D_offsets[t].item(); + const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item(); + const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item(); + + TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D)); + auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten(); + output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values); + D_offset += D; + } + } + return output_new; + {%- else %} + return op.call( + host_weights, + weights_offsets, + D_offsets, + total_D, + hash_size_cumsum, + indices, + offsets, + pooling_mode, + indice_weights, + output_dtype); + {%- endif %} + } {% else %} {#-/* PT2 wrapper function for backward CPU */#} -Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper( +Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper( const Tensor& grad_output, const Tensor& host_weights, const Tensor& /*dev_weights*/, @@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap const int64_t /*BT_block_size*/, const int64_t /*max_segment_length_per_warp*/, const bool stochastic_rounding, - const int64_t /*info_B_num_bits*/, - const int64_t /*info_B_mask_int64*/, + const int64_t info_B_num_bits, + const int64_t info_B_mask_int64, + {%- if vbe %} + const Tensor& B_offsets, + const Tensor& vbe_row_output_offsets, + const Tensor& vbe_b_t_map, + {%- endif %} const bool /*use_uniq_cache_locations*/, const bool /*use_homogeneous_placements*/, {{ args_pt2.split_function_args | join(", ") }} @@ -194,29 +258,30 @@ namespace { TORCH_LIBRARY_FRAGMENT(fbgemm, m) { {%- if is_forward %} DISPATCH_TO_CPU( - "split_embedding_codegen_grad_indice_weights_pt2_wrapper", - split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper); + "split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper", + split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper); {%- endif %} {%- for weighted in [True, False] %} {%- set wdesc = "weighted" if weighted else "unweighted" %} {%- if is_forward %} - {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format( - wdesc + {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format( + wdesc, vdesc ) %} DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper); {%- else %} - {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format( - optimizer, wdesc + {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format( + optimizer, wdesc, vdesc ) %} DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper); {%- endif %} {%- endfor %} {#-/*for weighted*/#} } - } // namespace +{%- endfor %} {#-/* for vbe in [True, False] */#} + {% endif %} // if has_cpu_support // clang-format on diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp new file mode 100644 index 0000000000..071acf90aa --- /dev/null +++ b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +//////////////////////////////////////////////////////////////////////////////// +// Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +Tensor reshape_vbe_output( + const Tensor& grad_output, + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& D_offsets) { + /* FOR CPU VBE to use the same backend */ + const auto T = D_offsets.numel() - 1; + int32_t max_B = 0; + int32_t total_D = 0; + // find max_B, total_D to create output [max_B, total_D] + for (int32_t t = 0; t < T; t++) { + auto b = B_offsets[t + 1].item() - B_offsets[t].item(); + max_B = std::max(max_B, b); + total_D += D_offsets[t + 1].item() - D_offsets[t].item(); + } + auto grad_output_ = at::empty({max_B, total_D}, grad_output.options()); + // for each feature + auto offset = 0; + + const int32_t R = B_offsets_rank_per_feature.size(1) - 1; + for (int32_t r = 0; r < R; r++) { + auto D_offset = 0; + for (int32_t t = 0; t < T; t++) { + const int32_t b_begin = B_offsets_rank_per_feature[t][r].item(); + const int32_t b_end = + B_offsets_rank_per_feature[t][r + 1].item(); + const int32_t D = + D_offsets[t + 1].item() - D_offsets[t].item(); + const int32_t b = b_end - b_begin; + const int32_t num_elm = b * D; + auto values = grad_output.slice(0, offset, offset + num_elm); + values = values.reshape({b, D}); + grad_output_.index_put_( + {at::indexing::Slice(b_begin, b_end), + at::indexing::Slice(D_offset, D_offset + D)}, + values); + D_offset += D; + offset += num_elm; + } + } + return grad_output_; +} +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h new file mode 100644 index 0000000000..3aff58c9af --- /dev/null +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +// #include +// #include +// #include "fbgemm_gpu/embedding_common.h" +// #include "fbgemm_gpu/utils/dispatch_macros.h" +// #include "fbgemm_gpu/utils/ops_utils.h" +// #include "fbgemm_gpu/utils/tensor_utils.h" + +using Tensor = at::Tensor; + +namespace fbgemm_gpu { + +//////////////////////////////////////////////////////////////////////////////// +// Helper Functions +//////////////////////////////////////////////////////////////////////////////// + +Tensor reshape_vbe_output( + const Tensor& grad_output, + const Tensor& B_offsets, + const Tensor& B_offsets_rank_per_feature, + const Tensor& D_offsets); +} // namespace fbgemm_gpu