From 154a4cc4f463f409240309ae12294be5b1e79673 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Fri, 10 Jan 2025 14:34:48 -0800
Subject: [PATCH] Add support for `int32_t` indices in TBE training (2H/N)
 (#3539)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/626

- Update benchmark test for `int32_t` indices

Reviewed By: sryap

Differential Revision: D67784746
---
 fbgemm_gpu/bench/bench_utils.py               |  2 +
 ...plit_table_batched_embeddings_benchmark.py | 13 ++-
 .../embedding_backward_split_cpu_template.cpp | 87 +++++++++++--------
 ...t_table_batched_embeddings_ops_training.py | 12 ++-
 4 files changed, 71 insertions(+), 43 deletions(-)

diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py
index 830415945c..36e9a69ec6 100644
--- a/fbgemm_gpu/bench/bench_utils.py
+++ b/fbgemm_gpu/bench/bench_utils.py
@@ -168,6 +168,8 @@ def benchmark_requests(

     if num_warmups > 0:
         indices, offsets, weights = requests[0].unpack_3()
+        print(f"INDICES BENCHMARK {indices.dtype}")
+        print(f"OFFSETS BENCHMARK {offsets.dtype}")
         for _ in range(num_warmups):
             out = func(indices, offsets, weights)
             if bwd_only:
diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index f439ed6780..a9f0de9d78 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -125,6 +125,7 @@ def cli() -> None:
 @click.option("--flush-gpu-cache-size-mb", default=0)
 @click.option("--dense", is_flag=True, default=False)
 @click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
+@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
 @click.option("--requests_data_file", type=str, default=None)
 @click.option("--tables", type=str, default=None)
 @click.option("--export-trace", is_flag=True, default=False)
@@ -166,6 +167,7 @@ def device(  # noqa C901
     flush_gpu_cache_size_mb: int,
     dense: bool,
     output_dtype: SparseType,
+    indices_dtype: str,
     requests_data_file: Optional[str],
     tables: Optional[str],
     export_trace: bool,
@@ -176,6 +178,9 @@ def device(  # noqa C901
     cache_load_factor: float,
 ) -> None:
     assert not ssd or not dense, "--ssd cannot be used together with --dense"
+    indices_dtype_torch: torch.dtype = (
+        torch.int32 if int(indices_dtype) == 32 else torch.int64
+    )
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size
@@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
     time_per_iter = benchmark_requests(
         requests,
         lambda indices, offsets, per_sample_weights: emb.forward(
-            indices.long(),
-            offsets.long(),
+            indices.to(dtype=indices_dtype_torch),
+            offsets.to(dtype=indices_dtype_torch),
             per_sample_weights,
             feature_requires_grad=feature_requires_grad,
         ),
@@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
     time_per_iter = benchmark_requests(
         requests,
         lambda indices, offsets, per_sample_weights: emb(
-            indices.long(),
-            offsets.long(),
+            indices.to(dtype=indices_dtype_torch),
+            offsets.to(dtype=indices_dtype_torch),
             per_sample_weights,
             feature_requires_grad=feature_requires_grad,
         ),
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp
index 29cc9eb8b8..b2c0bb4785 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp
@@ -40,7 +40,7 @@ struct half2float16 {
 } // namespace internal

 namespace {
-template <typename scalar_t, typename grad_t>
+template <typename index_t, typename scalar_t, typename grad_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
@@ -90,8 +90,8 @@ for (const auto t : c10::irange(num_tables)) {
       ::internal::csr2csc(
           cscs[t],
           B,
-          offsets.accessor<int64_t, 1>(),
-          indices.accessor<int64_t, 1>(),
+          offsets.accessor<index_t, 1>(),
+          indices.accessor<index_t, 1>(),
           indice_weights.defined()
              ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
              : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr),
@@ -223,7 +223,7 @@ for (const auto d : c10::irange(D)) {
   } // for each table
 }

-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void split_embedding_backward_exact_cpu_dense_kernel(
     Tensor grad,
     Tensor grad_output,
@@ -240,8 +240,8 @@ void split_embedding_backward_exact_cpu_dense_kernel(

   auto grad_output_data = grad_output.accessor<scalar_t, 2>();

-  const auto indices_data = indices.accessor<int64_t, 1>();
-  const auto offsets_data = offsets.accessor<int64_t, 1>();
+  const auto indices_data = indices.accessor<index_t, 1>();
+  const auto offsets_data = offsets.accessor<index_t, 1>();
   const auto indice_weights_data = indice_weights.defined()
       ?
       // If indice_weights are not defined, then this accessor won't be
@@ -347,34 +347,42 @@ for (const auto d : c10::irange(D)) {

   grad_output = grad_output.contiguous();

-
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_kernel_1", [&] {
+        using index_t = scalar_t;
+
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
             grad_output.scalar_type(),
-            "split_embedding_backward_exact_cpu_outer", [&]() {
-              using grad_t = scalar_t;
+            "split_embedding_backward_exact_cpu_kernel_2", [&] {
+              using grad_t = scalar_t;
+
               FBGEMM_DISPATCH_FLOAT_AND_HALF(
-                  host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
-                    split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
-                        grad_output,
-                        host_weights,
-                        weights_offsets_data,
-                        D_offsets_data,
-                        hash_size_cumsum,
-                        indices,
-                        offsets,
-                        pooling_mode,
-                        indice_weights,
-                        num_tables,
-                        B,
-                        table_to_feature_offset,
-                        {% if "momentum1_offsets" in args.split_function_arg_names %}
-                        momentum1_offsets_data,
-                        {% endif %}
-                        {% if "momentum2_offsets" in args.split_function_arg_names %}
-                        momentum2_offsets_data,
-                        {% endif %}
-                        {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
-                  });
+                  host_weights.scalar_type(),
+                  "split_embedding_backward_exact_cpu_kernel_3", [&] {
+
+                    split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
+                        grad_output,
+                        host_weights,
+                        weights_offsets_data,
+                        D_offsets_data,
+                        hash_size_cumsum,
+                        indices,
+                        offsets,
+                        pooling_mode,
+                        indice_weights,
+                        num_tables,
+                        B,
+                        table_to_feature_offset,
+                        {% if "momentum1_offsets" in args.split_function_arg_names %}
+                        momentum1_offsets_data,
+                        {% endif %}
+                        {% if "momentum2_offsets" in args.split_function_arg_names %}
+                        momentum2_offsets_data,
+                        {% endif %}
+                        {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+                  });
+            });
       });

   return;
@@ -383,10 +391,16 @@ for (const auto d : c10::irange(D)) {

   // When input is dense enough, avoid sorting and just treat as dense.
   auto grad = zeros_like(host_weights, grad_output.dtype());
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
-      grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_dense_kernel", [&] {
+        using index_t = scalar_t;

-        split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
+            grad_output.scalar_type(),
+            "split_embedding_backward_exact_cpu", [&] {
+
+              split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
             grad,
             grad_output,
             weights_offsets_data,
@@ -398,7 +412,8 @@ for (const auto d : c10::irange(D)) {
             num_tables,
             B,
             table_to_feature_offset);
-      }); // dispatch host_weights.scalar_type()
+      });
+  });

   return grad;
 {% endif %}
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 8f8d5779ea..cbdb6645c9 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -3315,8 +3315,10 @@ def prepare_inputs(
         )

     if force_cast_input_types:
-        # Force casting indices and offsets to long
-        (indices, offsets) = indices.long(), offsets.long()
+        # NOTE: Force offsets to have the same dtype as indices, since the
+        # kernels assume that they share the same dtype.  We may need to
+        # revisit this assumption in the future.
+        offsets = offsets.to(dtype=indices.dtype)

         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
@@ -3681,7 +3683,11 @@ def forward(
                 offsets, batch_size_per_feature_per_rank
             )

-        (indices, offsets) = indices.long(), offsets.long()
+        # NOTE: Force offsets to have the same dtype as indices, since the
+        # kernels assume that they share the same dtype.  We may need to
+        # revisit this assumption in the future.
+        offsets = offsets.to(dtype=indices.dtype)
+
         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
             per_sample_weights = per_sample_weights.float()
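
Usage sketch (an illustrative note, not part of the patch): the dtype plumbing
above reduces to two small pieces. The benchmark maps the `--indices-dtype`
CLI choice to a torch dtype, and the training frontend now casts `offsets` to
match `indices` instead of promoting both to int64 with `.long()`. The helper
names below (`to_indices_dtype`, `prepare_inputs_sketch`) are hypothetical
stand-ins written for this note; they mirror the behavior in the diff but are
not FBGEMM APIs.

    import torch

    def to_indices_dtype(indices_dtype: str) -> torch.dtype:
        # Mirrors the --indices-dtype handling added to the benchmark:
        # "32" selects torch.int32, anything else ("64") torch.int64.
        return torch.int32 if int(indices_dtype) == 32 else torch.int64

    def prepare_inputs_sketch(
        indices: torch.Tensor, offsets: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Mirrors the new casting behavior: offsets are cast to the dtype
        # of indices (the kernels assume both share a dtype), while indices
        # keep whatever integral dtype the caller provided.
        return indices, offsets.to(dtype=indices.dtype)

    # With --indices-dtype 32, int32 indices now flow through unchanged,
    # where the old code would have promoted both tensors to int64.
    dtype = to_indices_dtype("32")
    indices = torch.tensor([1, 3, 5, 2], dtype=dtype)
    offsets = torch.tensor([0, 2, 4], dtype=torch.int64)
    indices, offsets = prepare_inputs_sketch(indices, offsets)
    assert indices.dtype == offsets.dtype == torch.int32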