diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp index 705eaa61b0..71a2294864 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp @@ -224,6 +224,32 @@ void nccl_allreduce( default: TORCH_CHECK(false, "unsupported type: ", src.scalar_type()); } +#if defined(USE_ROCM) + if (bias) { + C10D_NCCL_CHECK( + ncclAllReduceWithBias( + src.data_ptr(), + dst.data_ptr(), + src.numel(), + type, + ncclSum, + *get_nccl_comm(comm_idx), + at::cuda::getCurrentCUDAStream(), + (*bias).data_ptr()), + "ncclAllReduceWithBias"); + } else { + C10D_NCCL_CHECK( + ncclAllReduce( + src.data_ptr(), + dst.data_ptr(), + src.numel(), + type, + ncclSum, + *get_nccl_comm(comm_idx), + at::cuda::getCurrentCUDAStream()), + "ncclAllReduce"); + } +#else C10D_NCCL_CHECK( ncclAllReduce( src.data_ptr(), @@ -237,6 +263,7 @@ void nccl_allreduce( if (bias) { dst.add_(*bias); } +#endif } } // namespace