diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index ccf4805cd3..99d5e5b0c5 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -263,10 +263,10 @@ list(APPEND gen_gpu_host_source_files
 foreach(optimizer ${ALL_OPTIMIZERS})
   list(APPEND gen_cpu_source_files
     "gen_embedding_backward_split_${optimizer}_cpu.cpp"
-    "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp")
+    "gen_embedding_backward_split_${optimizer}_pt2_cpu_wrapper.cpp"
+    "gen_embedding_split_${optimizer}_pt2_autograd.cpp")
   list(APPEND gen_gpu_host_source_files
     "gen_embedding_backward_split_${optimizer}.cpp"
-    "gen_embedding_split_${optimizer}_pt2_autograd.cpp"
     "gen_embedding_backward_split_${optimizer}_pt2_cuda_wrapper.cpp")
 endforeach()
 
@@ -454,6 +454,7 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/training/forward/embedding_forward_split_cpu.cpp
     codegen/inference/embedding_forward_quantized_host_cpu.cpp
     codegen/training/backward/embedding_backward_dense_host_cpu.cpp
+    codegen/training/pt2/pt2_autograd_utils.cpp
     codegen/utils/embedding_bounds_check_host_cpu.cpp
     src/config/feature_gates.cpp
     src/memory_utils/memory_utils.cpp
@@ -480,6 +481,7 @@ set(fbgemm_gpu_sources_static_cpu
     src/split_embeddings_cache/lru_cache_populate_byte.cpp
     src/split_embeddings_cache/lxu_cache.cpp
     src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+    src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
     codegen/training/index_select/batch_index_select_dim0_ops.cpp
     codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)
 
diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py
index 285cf9a559..894ce104ce 100644
--- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py
+++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py
@@ -83,6 +83,7 @@ def generate_pt2_wrappers() -> None:
             f"gen_embedding_forward_split_pt2_cpu_wrapper.cpp",
             has_cpu_support=True,
             is_forward=True,
+            has_vbe_support=True,
         )
 
         # Generate PT2 forward wrapper (CUDA)
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
index b623b92d03..3666de5b9a 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp
@@ -35,8 +35,12 @@
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include <torch/script.h>
 #include "fbgemm_gpu/utils/dispatch_macros.h"
-#include "fbgemm_gpu/split_embeddings_utils.cuh"
+#include "fbgemm_gpu/embedding_common.h"
+#include "fbgemm_gpu/split_embeddings_utils.h"
 #include "fbgemm_gpu/config/feature_gates.h"
+{%- if has_vbe_support %}
+#include "fbgemm_gpu/utils/pt2_autograd_utils.h"
+{%- endif %}
 
 using Tensor = at::Tensor;
 
@@ -236,9 +240,9 @@ enum SSDTensor {
     const Tensor& /*prev_iter_dev*/,
     {%- endif %}
     {%- if "iter" not in args_pt2.split_function_arg_names %}
-    const int64_t iter,
+    const int64_t /*iter*/,
     {%- endif %}
-    const double gwd_lower_bound,
+    const double /*gwd_lower_bound*/,
     {%- endif %} {# /* if is_gwd */ #}
     {%- for arg_type in args_pt2.split_function_args %}
     {{ arg_type.split(' ')[0]}}{%- if not loop.last %}{{ "," }}{%- endif %}
@@ -617,7 +621,6 @@ class {{ autograd_func }} :
                 const c10::SymInt,
                 const int64_t,
                 const c10::SymInt)>();
-
     auto [
         vbe_row_output_offsets,
         vbe_b_t_map
@@ -850,6 +853,11 @@ static torch::autograd::variable_list backward(
     // {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda)
     weights_dev = weights_dev.flatten();
     {%- endif %}
+    {%- if vbe %}
+    if (weights_host.numel() > 1){
+      grad_output = reshape_vbe_output(grad_output, B_offsets, vbe_b_t_map, D_offsets);
+    }
+    {%- endif %}
 
     {%- set grad_indice_weights_op =
         "{}_embedding_codegen_grad_indice_weights{}_pt2_wrapper".format(fwd_mdesc, vdesc)
@@ -883,7 +891,7 @@ static torch::autograd::variable_list backward(
         {%- else %}
         const Tensor& /*feature_requires_grad*/
         {%- endif %}
-        )>();
+        )>();
 
     const auto grad_indice_weights = !indice_weights.defined()
         ? Variable() :
@@ -1014,7 +1022,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
   {%- if not ssd %}
   {%- if has_vbe_support %}
   // has vbe support and on gpu
-  if (B_offsets.has_value() && !(weights[0].numel() > 0)) {
+  if (B_offsets.has_value()) {
     {%- if has_global_weight_decay_support %}
     // vbe and has gwd support
     if (apply_global_weight_decay && weight_decay > 0) {
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
index c74355207f..5b2b066fee 100644
--- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
+++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_cpu_wrapper_template.cpp
@@ -30,9 +30,12 @@
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
+{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
+{%- set vdesc = "_vbe" if vbe else "" %}
+
 {%- if is_forward %}
 {#-/* PT2 wrapper function for backward grad_indice_weights CPU */#}
-Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
+Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& grad_output,
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
@@ -45,7 +48,16 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
     const Tensor& indices,
     const Tensor& offsets,
     const Tensor& /*lxu_cache_locations*/,
-    const Tensor& feature_requires_grad) {
+    {%- if vbe %}
+    const Tensor& feature_requires_grad,
+    const Tensor& vbe_row_output_offsets,
+    const Tensor& vbe_b_t_map,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask_int64
+    {%- else %}
+    const Tensor& feature_requires_grad
+    {%- endif %}
+) {
   static auto op =
       torch::Dispatcher::singleton()
           .findSchemaOrThrow(
@@ -67,7 +79,7 @@ Tensor split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper(
 
 {% if is_forward %}
 {#-/* PT2 wrapper function for forward CPU */#}
-Tensor split_embedding_codegen_forward_{{ wdesc }}_pt2_cpu_wrapper(
+Tensor split_embedding_codegen_forward_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
     const Tensor& /*uvm_weights*/,
@@ -84,30 +96,77 @@
     const Tensor& indice_weights,
     const Tensor& /*lxu_cache_locations*/,
     const Tensor& /*uvm_cache_stats*/,
+    {%- if vbe %}
+    const Tensor& vbe_row_output_offsets, /*vbe_output_offsets_feature_rank*/
+    const Tensor& vbe_b_t_map, /*vbe_B_offsets_rank_per_feature*/
+    const c10::SymInt vbe_output_size,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask_int64,
+    {%- endif %}
     const bool /*is_experimental = false*/,
     const int64_t output_dtype = static_cast<int64_t>(SparseType::FP32)) {
-  static auto op =
-      torch::Dispatcher::singleton()
-          .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
-          .typed();
+  static auto op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
+          .typed();
+  {%- if vbe %}
+  // TODO: remove this after vbe is implemented for CPU kernel
+  Tensor vbe_B_offsets_rank_per_feature = vbe_b_t_map;
+  Tensor vbe_output_offsets_feature_rank = vbe_row_output_offsets;
+  const auto output = op.call(
+      host_weights,
+      weights_offsets,
+      D_offsets,
+      total_D,
+      hash_size_cumsum,
+      indices,
+      offsets,
+      pooling_mode,
+      indice_weights,
+      output_dtype);
+  auto options = at::TensorOptions()
+                     .dtype(output.options().dtype())
+                     .device(host_weights.options().device());
+  const int64_t vbe_output_size_ = vbe_output_size.guard_int(__FILE__, __LINE__);
+  Tensor output_new = at::empty({vbe_output_size_}, options);
+  const int32_t T = D_offsets.numel() - 1;
+  const int32_t R = vbe_B_offsets_rank_per_feature.size(1) - 1;
 
-  return op.call(
-      host_weights,
-      weights_offsets,
-      D_offsets,
-      total_D,
-      hash_size_cumsum,
-      indices,
-      offsets,
-      pooling_mode,
-      indice_weights,
-      output_dtype);
-}
+  for (int32_t r = 0; r < R; r++){
+    auto D_offset = 0;
+    for (int32_t t = 0; t < T; t++){
+      const int32_t o_begin = vbe_output_offsets_feature_rank[r * T + t].item<int32_t>();
+      const int32_t o_end = vbe_output_offsets_feature_rank[r * T + t + 1].item<int32_t>();
+      const int32_t D = D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
+      const int32_t b_begin = vbe_B_offsets_rank_per_feature[t][r].item<int32_t>();
+      const int32_t b_end = vbe_B_offsets_rank_per_feature[t][r + 1].item<int32_t>();
+
+      TORCH_CHECK((o_end - o_begin) == ((b_end - b_begin) * D));
+      auto values = output.index({torch::indexing::Slice(b_begin, b_end), torch::indexing::Slice(D_offset, D_offset + D)}).flatten();
+      output_new.index_put_({torch::indexing::Slice(o_begin, o_end)}, values);
+      D_offset += D;
+    }
+  }
+  return output_new;
+  {%- else %}
+  return op.call(
+      host_weights,
+      weights_offsets,
+      D_offsets,
+      total_D,
+      hash_size_cumsum,
+      indices,
+      offsets,
+      pooling_mode,
+      indice_weights,
+      output_dtype);
+  {%- endif %}
+  }
 
 {% else %}
 {#-/* PT2 wrapper function for backward CPU */#}
-Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrapper(
+Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_pt2_cpu_wrapper(
     const Tensor& grad_output,
     const Tensor& host_weights,
     const Tensor& /*dev_weights*/,
@@ -127,8 +186,13 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}_pt2_cpu_wrap(
     const int64_t /*BT_block_size*/,
     const int64_t /*max_segment_length_per_warp*/,
     const bool stochastic_rounding,
-    const int64_t /*info_B_num_bits*/,
-    const int64_t /*info_B_mask_int64*/,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask_int64,
+    {%- if vbe %}
+    const Tensor& B_offsets,
+    const Tensor& vbe_row_output_offsets,
+    const Tensor& vbe_b_t_map,
+    {%- endif %}
     const bool /*use_uniq_cache_locations*/,
     const bool /*use_homogeneous_placements*/,
     {{ args_pt2.split_function_args | join(", ") }}
@@ -194,29 +258,30 @@ namespace {
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   {%- if is_forward %}
   DISPATCH_TO_CPU(
-      "split_embedding_codegen_grad_indice_weights_pt2_wrapper",
-      split_embedding_codegen_grad_indice_weights_pt2_cpu_wrapper);
+      "split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper",
+      split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper);
   {%- endif %}
 
   {%- for weighted in [True, False] %}
   {%- set wdesc = "weighted" if weighted else "unweighted" %}
   {%- if is_forward %}
-  {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}_pt2".format(
-      wdesc
+  {%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format(
+      wdesc, vdesc
     )
   %}
   DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper);
   {%- else %}
-  {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}_pt2".format(
-      optimizer, wdesc
+  {%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format(
+      optimizer, wdesc, vdesc
    )
   %}
   DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper);
   {%- endif %}
   {%- endfor %} {#-/*for weighted*/#}
 }
-
 } // namespace
+{%- endfor %} {#-/* for vbe in [True, False] */#}
+
 {% endif %} // if has_cpu_support
 // clang-format on
diff --git a/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
new file mode 100644
index 0000000000..071acf90aa
--- /dev/null
+++ b/fbgemm_gpu/codegen/training/pt2/pt2_autograd_utils.cpp
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/TypeDefault.h>
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper Functions
+////////////////////////////////////////////////////////////////////////////////
+
+Tensor reshape_vbe_output(
+    const Tensor& grad_output,
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& D_offsets) {
+  /* FOR CPU VBE to use the same backend */
+  const auto T = D_offsets.numel() - 1;
+  int32_t max_B = 0;
+  int32_t total_D = 0;
+  // find max_B, total_D to create output [max_B, total_D]
+  for (int32_t t = 0; t < T; t++) {
+    auto b = B_offsets[t + 1].item<int32_t>() - B_offsets[t].item<int32_t>();
+    max_B = std::max(max_B, b);
+    total_D += D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
+  }
+  auto grad_output_ = at::empty({max_B, total_D}, grad_output.options());
+  // for each feature
+  auto offset = 0;
+
+  const int32_t R = B_offsets_rank_per_feature.size(1) - 1;
+  for (int32_t r = 0; r < R; r++) {
+    auto D_offset = 0;
+    for (int32_t t = 0; t < T; t++) {
+      const int32_t b_begin = B_offsets_rank_per_feature[t][r].item<int32_t>();
+      const int32_t b_end =
+          B_offsets_rank_per_feature[t][r + 1].item<int32_t>();
+      const int32_t D =
+          D_offsets[t + 1].item<int32_t>() - D_offsets[t].item<int32_t>();
+      const int32_t b = b_end - b_begin;
+      const int32_t num_elm = b * D;
+      auto values = grad_output.slice(0, offset, offset + num_elm);
+      values = values.reshape({b, D});
+      grad_output_.index_put_(
+          {at::indexing::Slice(b_begin, b_end),
+           at::indexing::Slice(D_offset, D_offset + D)},
+          values);
+      D_offset += D;
+      offset += num_elm;
+    }
+  }
+  return grad_output_;
+}
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h
new file mode 100644
index 0000000000..3aff58c9af
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/pt2_autograd_utils.h
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/TypeDefault.h>
+// #include
+// #include
+// #include "fbgemm_gpu/embedding_common.h"
+// #include "fbgemm_gpu/utils/dispatch_macros.h"
+// #include "fbgemm_gpu/utils/ops_utils.h"
+// #include "fbgemm_gpu/utils/tensor_utils.h"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper Functions
+////////////////////////////////////////////////////////////////////////////////
+
+Tensor reshape_vbe_output(
+    const Tensor& grad_output,
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& D_offsets);
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
index c3eb40819d..a4efd4c212 100644
--- a/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
+++ b/fbgemm_gpu/src/split_embeddings_utils/get_infos_metadata.cu
@@ -13,52 +13,6 @@
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
-DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
-    int32_t B,
-    int32_t T) {
-  int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
-  uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
-  uint32_t max_T = MAX_T;
-  uint32_t max_B = MAX_B;
-  bool invalid_T = T > max_T;
-  bool invalid_B = B > max_B;
-
-  TORCH_CHECK(
-      !(invalid_T && invalid_B),
-      "Not enough infos bits to accommodate T and B. Default num bits = ",
-      DEFAULT_INFO_NUM_BITS);
-
-  if (invalid_T) {
-    // Reduce info_B_num_bits
-    while (invalid_T && !invalid_B && info_B_num_bits > 0) {
-      info_B_num_bits--;
-      max_T = ((max_T + 1) << 1) - 1;
-      max_B = ((max_B + 1) >> 1) - 1;
-      invalid_T = T > max_T;
-      invalid_B = B > max_B;
-    }
-  } else if (invalid_B) {
-    // Increase info_B_num_bits
-    while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
-      info_B_num_bits++;
-      max_T = ((max_T + 1) >> 1) - 1;
-      max_B = ((max_B + 1) << 1) - 1;
-      invalid_T = T > max_T;
-      invalid_B = B > max_B;
-    }
-  }
-
-  TORCH_CHECK(
-      !invalid_T && !invalid_B,
-      "Not enough infos bits to accommodate T and B. Default num bits = ",
-      DEFAULT_INFO_NUM_BITS);
-
-  // Recompute info_B_mask using new info_B_num_bits
-  info_B_mask = (1u << info_B_num_bits) - 1;
-
-  return {info_B_num_bits, info_B_mask};
-}
-
 DLL_PUBLIC std::tuple<int64_t, int64_t>
 get_infos_metadata(Tensor unused, int64_t B, int64_t T) {
   return adjust_info_B_num_bits(B, T);
diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
index 4ae9ae0f70..8902e1c446 100644
--- a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
+++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp
@@ -33,58 +33,12 @@ generate_vbe_metadata_meta(
   return {row_output_offsets, b_t_map};
 }
 
-std::tuple<Tensor, Tensor>
-generate_vbe_metadata_cpu(
-    const Tensor& B_offsets,
-    const Tensor& B_offsets_rank_per_feature,
-    const Tensor& output_offsets_feature_rank,
-    const Tensor& D_offsets,
-    const int64_t D,
-    const bool nobag,
-    const c10::SymInt max_B_feature_rank,
-    const int64_t info_B_num_bits,
-    const c10::SymInt total_B) {
-  Tensor row_output_offsets = output_offsets_feature_rank;
-  Tensor b_t_map = B_offsets_rank_per_feature;
-  return {row_output_offsets, b_t_map};
-}
-
 } // namespace
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def(
-      "transpose_embedding_input("
-      "    Tensor hash_size_cumsum, "
-      "    int total_hash_size_bits, "
-      "    Tensor indices, "
-      "    Tensor offsets, "
-      "    bool nobag=False, "
-      "    Tensor? vbe_b_t_map=None, "
-      "    int info_B_num_bits=26, "
-      "    int info_B_mask=0x2FFFFFF, "
-      "    int total_unique_indices=-1, "
-      "    bool is_index_select=False, "
-      "    Tensor? total_L_offsets=None, "
-      "    int fixed_L_per_warp=0, "
-      "    int num_warps_per_feature=0"
-      ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
-  m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)");
-  m.def(
-      "generate_vbe_metadata("
-      "    Tensor B_offsets, "
-      "    Tensor B_offsets_rank_per_feature, "
-      "    Tensor output_offsets_feature_rank, "
-      "    Tensor D_offsets, "
-      "    int D, "
-      "    bool nobag, "
-      "    SymInt max_B_feature_rank, "
-      "    int info_B_num_bits, "
-      "    SymInt total_B"
-      ") -> (Tensor, Tensor)");
   DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input);
   DISPATCH_TO_CUDA("get_infos_metadata", get_infos_metadata);
   DISPATCH_TO_CUDA("generate_vbe_metadata", generate_vbe_metadata);
-  DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
diff --git a/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
new file mode 100644
index 0000000000..654a3c3edc
--- /dev/null
+++ b/fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils_cpu.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <torch/library.h>
+#include "fbgemm_gpu/split_embeddings_utils.h"
+#include "fbgemm_gpu/utils/ops_utils.h"
+
+using Tensor = at::Tensor;
+
+DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
+    int32_t B,
+    int32_t T) {
+  int32_t info_B_num_bits = DEFAULT_INFO_B_NUM_BITS;
+  uint32_t info_B_mask = DEFAULT_INFO_B_MASK;
+  uint32_t max_T = MAX_T;
+  uint32_t max_B = MAX_B;
+  bool invalid_T = T > max_T;
+  bool invalid_B = B > max_B;
+
+  TORCH_CHECK(
+      !(invalid_T && invalid_B),
+      "Not enough infos bits to accommodate T and B. Default num bits = ",
+      DEFAULT_INFO_NUM_BITS);
+
+  if (invalid_T) {
+    // Reduce info_B_num_bits
+    while (invalid_T && !invalid_B && info_B_num_bits > 0) {
+      info_B_num_bits--;
+      max_T = ((max_T + 1) << 1) - 1;
+      max_B = ((max_B + 1) >> 1) - 1;
+      invalid_T = T > max_T;
+      invalid_B = B > max_B;
+    }
+  } else if (invalid_B) {
+    // Increase info_B_num_bits
+    while (!invalid_T && invalid_B && info_B_num_bits < DEFAULT_INFO_NUM_BITS) {
+      info_B_num_bits++;
+      max_T = ((max_T + 1) >> 1) - 1;
+      max_B = ((max_B + 1) << 1) - 1;
+      invalid_T = T > max_T;
+      invalid_B = B > max_B;
+    }
+  }
+
+  TORCH_CHECK(
+      !invalid_T && !invalid_B,
+      "Not enough infos bits to accommodate T and B. Default num bits = ",
+      DEFAULT_INFO_NUM_BITS);
+
+  // Recompute info_B_mask using new info_B_num_bits
+  info_B_mask = (1u << info_B_num_bits) - 1;
+
+  return {info_B_num_bits, info_B_mask};
+}
+
+namespace {
+
+std::tuple<Tensor, Tensor>
+generate_vbe_metadata_cpu(
+    const Tensor& B_offsets,
+    const Tensor& B_offsets_rank_per_feature,
+    const Tensor& output_offsets_feature_rank,
+    const Tensor& D_offsets,
+    const int64_t D,
+    const bool nobag,
+    const c10::SymInt max_B_feature_rank,
+    const int64_t info_B_num_bits,
+    const c10::SymInt total_B) {
+  Tensor row_output_offsets = output_offsets_feature_rank;
+  Tensor b_t_map = B_offsets_rank_per_feature;
+  return {row_output_offsets, b_t_map};
+}
+
+std::tuple<int64_t, int64_t>
+get_infos_metadata_cpu(Tensor unused, int64_t B, int64_t T) {
+  return adjust_info_B_num_bits(B, T);
+}
+
+} // namespace
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.def(
+      "transpose_embedding_input("
+      "    Tensor hash_size_cumsum, "
+      "    int total_hash_size_bits, "
+      "    Tensor indices, "
+      "    Tensor offsets, "
+      "    bool nobag=False, "
+      "    Tensor? vbe_b_t_map=None, "
+      "    int info_B_num_bits=26, "
+      "    int info_B_mask=0x2FFFFFF, "
+      "    int total_unique_indices=-1, "
+      "    bool is_index_select=False, "
+      "    Tensor? total_L_offsets=None, "
+      "    int fixed_L_per_warp=0, "
+      "    int num_warps_per_feature=0"
+      ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
+  m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)");
+  m.def(
+      "generate_vbe_metadata("
+      "    Tensor B_offsets, "
+      "    Tensor B_offsets_rank_per_feature, "
+      "    Tensor output_offsets_feature_rank, "
+      "    Tensor D_offsets, "
+      "    int D, "
+      "    bool nobag, "
+      "    SymInt max_B_feature_rank, "
+      "    int info_B_num_bits, "
+      "    SymInt total_B"
+      ") -> (Tensor, Tensor)");
+  DISPATCH_TO_CPU("generate_vbe_metadata", generate_vbe_metadata_cpu);
+  DISPATCH_TO_CPU("get_infos_metadata", get_infos_metadata_cpu);
+}