Skip to content

Commit

Permalink
[XLA:SPMD] Fix scatter index-parallel partitioning issues.
Browse files Browse the repository at this point in the history
1. Fix gather/scatter partitioning where operand/updates sharding should be aligned with indices' sharding in index-parallel case.
2. Remove the assumption that gather/scatter index-parallel dim detection returns sorted dims where the dimension correspondence information is lost.

PiperOrigin-RevId: 675678391
  • Loading branch information
Tongfei-Guo authored and Google-ML-Automation committed Sep 17, 2024
1 parent 9e9fab0 commit f6b6175
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
20 changes: 11 additions & 9 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2319,7 +2319,6 @@ std::optional<GatherScatterParallelDims> GetGatherScatterBatchParallelDims(
index_parallel_in_dim[i] = -1;
}
}
absl::c_sort(indices_parallel_dims);
if (!indices_parallel_dims.empty()) {
return GatherScatterParallelDims{
indices_parallel_dims, operand_parallel_dims, index_parallel_in_dim};
Expand Down Expand Up @@ -2362,15 +2361,18 @@ GetGatherOutputOrScatterUpdateParallelDims(
int64_t index_vector_dim, absl::Span<const int64_t> offset_or_window_dims) {
absl::InlinedVector<int64_t, 1> output_parallel_dims;
auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) {
if (absl::c_linear_search(offset_or_window_dims, i)) {
continue;
}
const int index_dim = idx_dim < index_vector_dim ? idx_dim : idx_dim + 1;
if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
output_parallel_dims.push_back(i);
for (int64_t indices_parallel_dim : indices_parallel_dims) {
for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) {
if (absl::c_linear_search(offset_or_window_dims, i)) {
continue;
}
const int index_dim = idx_dim < index_vector_dim ? idx_dim : idx_dim + 1;
if (indices_parallel_dim == index_dim) {
output_parallel_dims.push_back(i);
break;
}
++idx_dim;
}
++idx_dim;
}
return output_parallel_dims;
}
Expand Down
20 changes: 12 additions & 8 deletions xla/service/spmd/gather_scatter_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,11 +684,13 @@ absl::StatusOr<HloInstruction*> PartitionGatherIndexParallelDimensions(
hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(),
indices_parallel_dims);
const GroupedSharding operand_grouped =
hlo_sharding_util::GroupShardingOnDims(operand.sharding(),
operand_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
operand.sharding(), operand_parallel_dims),
new_indices_grouped);
const GroupedSharding output_grouped =
hlo_sharding_util::GroupShardingOnDims(gather_output_sharding,
output_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
gather_output_sharding, output_parallel_dims),
new_indices_grouped);
PartitionedHlo per_group_operand =
PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups);
PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo(
Expand Down Expand Up @@ -1130,11 +1132,13 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexParallelDimensions(
hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(),
indices_parallel_dims);
const GroupedSharding operand_grouped =
hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(),
operand_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
operands[0].sharding(), operand_parallel_dims),
new_indices_grouped);
const GroupedSharding update_grouped =
hlo_sharding_util::GroupShardingOnDims(updates[0].sharding(),
update_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
updates[0].sharding(), update_parallel_dims),
new_indices_grouped);
const GroupedSharding& output_grouped = operand_grouped;
std::vector<PartitionedHlo> per_group_operands =
PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
Expand Down

0 comments on commit f6b6175

Please sign in to comment.