Skip to content

Commit

Permalink
[XLA:SPMD] Fix scatter index-parallel partitioning where operand/upda…
Browse files Browse the repository at this point in the history
…tes sharding should be aligned with indices' sharding.

PiperOrigin-RevId: 675439880
  • Loading branch information
Tongfei-Guo authored and Google-ML-Automation committed Sep 17, 2024
1 parent b3da4b6 commit fc31e16
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions xla/service/spmd/gather_scatter_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,11 +684,13 @@ absl::StatusOr<HloInstruction*> PartitionGatherIndexParallelDimensions(
hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(),
indices_parallel_dims);
const GroupedSharding operand_grouped =
hlo_sharding_util::GroupShardingOnDims(operand.sharding(),
operand_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
operand.sharding(), operand_parallel_dims),
new_indices_grouped);
const GroupedSharding output_grouped =
hlo_sharding_util::GroupShardingOnDims(gather_output_sharding,
output_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
gather_output_sharding, output_parallel_dims),
new_indices_grouped);
PartitionedHlo per_group_operand =
PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups);
PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo(
Expand Down Expand Up @@ -1130,11 +1132,13 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexParallelDimensions(
hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(),
indices_parallel_dims);
const GroupedSharding operand_grouped =
hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(),
operand_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
operands[0].sharding(), operand_parallel_dims),
new_indices_grouped);
const GroupedSharding update_grouped =
hlo_sharding_util::GroupShardingOnDims(updates[0].sharding(),
update_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
updates[0].sharding(), update_parallel_dims),
new_indices_grouped);
const GroupedSharding& output_grouped = operand_grouped;
std::vector<PartitionedHlo> per_group_operands =
PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
Expand Down

0 comments on commit fc31e16

Please sign in to comment.