From fc31e164cc03836106392793ab2350d937ec509b Mon Sep 17 00:00:00 2001 From: Tongfei Guo Date: Mon, 16 Sep 2024 23:55:20 -0700 Subject: [PATCH] [XLA:SPMD] Fix scatter index-parallel partitioning where operand/updates sharding should be aligned with indices' sharding. PiperOrigin-RevId: 675439880 --- xla/service/spmd/gather_scatter_handler.cc | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc index 21cc82fc0ddf14..fd45e649f18cd2 100644 --- a/xla/service/spmd/gather_scatter_handler.cc +++ b/xla/service/spmd/gather_scatter_handler.cc @@ -684,11 +684,13 @@ absl::StatusOr PartitionGatherIndexParallelDimensions( hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), indices_parallel_dims); const GroupedSharding operand_grouped = - hlo_sharding_util::GroupShardingOnDims(operand.sharding(), - operand_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operand.sharding(), operand_parallel_dims), + new_indices_grouped); const GroupedSharding output_grouped = - hlo_sharding_util::GroupShardingOnDims(gather_output_sharding, - output_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + gather_output_sharding, output_parallel_dims), + new_indices_grouped); PartitionedHlo per_group_operand = PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups); PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( @@ -1130,11 +1132,13 @@ absl::StatusOr PartitionScatterIndexParallelDimensions( hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), indices_parallel_dims); const GroupedSharding operand_grouped = - hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(), - operand_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operands[0].sharding(), operand_parallel_dims), + new_indices_grouped); const GroupedSharding update_grouped = - hlo_sharding_util::GroupShardingOnDims(updates[0].sharding(), - update_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + updates[0].sharding(), update_parallel_dims), + new_indices_grouped); const GroupedSharding& output_grouped = operand_grouped; std::vector per_group_operands = PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);