diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu index 702f5a42a9..af9340afb6 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops.cu @@ -1202,22 +1202,24 @@ Tensor dense_to_jagged_forward( at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dense.get_device()); +#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \ + AT_DISPATCH_CASE(TYPE, [&] { \ + jagged_dense_elementwise_jagged_output_opt_( \ + values, \ + offsets, \ + dense, \ + output, \ + [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { \ + return y; \ + }); \ + }) + + // clang-format off AT_DISPATCH_SWITCH( values.scalar_type(), "dense_to_jagged_gpu_op_forward", - AT_DISPATCH_CASE( - at::ScalarType::Half, - [&] { - jagged_dense_elementwise_jagged_output_opt_( - values, - offsets, - dense, - output, - [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t { - return y; - }); // device lambda - } // lambda - ) // CASE + DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Half) + DISPATCH_DENSE_TO_JAGGED_CASE(at::ScalarType::Int) AT_DISPATCH_CASE_FLOATING_TYPES_AND2( at::ScalarType::Long, at::ScalarType::BFloat16, @@ -1233,6 +1235,9 @@ Tensor dense_to_jagged_forward( } // lambda ) // CASE_FLOATING_TYPES_AND ); // SWITCH + // clang-format on + +#undef DISPATCH_DENSE_TO_JAGGED_CASE return output; }