diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index c8d7ead83f1d1..c3cd98219899e 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -2319,7 +2319,6 @@ std::optional GetGatherScatterBatchParallelDims( index_parallel_in_dim[i] = -1; } } - absl::c_sort(indices_parallel_dims); if (!indices_parallel_dims.empty()) { return GatherScatterParallelDims{ indices_parallel_dims, operand_parallel_dims, index_parallel_in_dim}; @@ -2362,15 +2361,18 @@ GetGatherOutputOrScatterUpdateParallelDims( int64_t index_vector_dim, absl::Span offset_or_window_dims) { absl::InlinedVector output_parallel_dims; auto indices_parallel_dims = parallel_dim.indices_parallel_dims; - for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) { - if (absl::c_linear_search(offset_or_window_dims, i)) { - continue; - } - const int index_dim = idx_dim < index_vector_dim ? idx_dim : idx_dim + 1; - if (absl::c_binary_search(indices_parallel_dims, index_dim)) { - output_parallel_dims.push_back(i); + for (int64_t indices_parallel_dim : indices_parallel_dims) { + for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) { + if (absl::c_linear_search(offset_or_window_dims, i)) { + continue; + } + const int index_dim = idx_dim < index_vector_dim ? idx_dim : idx_dim + 1; + if (indices_parallel_dim == index_dim) { + output_parallel_dims.push_back(i); + break; + } + ++idx_dim; } - ++idx_dim; } return output_parallel_dims; } diff --git a/xla/service/spmd/gather_scatter_handler.cc b/xla/service/spmd/gather_scatter_handler.cc index 21cc82fc0ddf1..fd45e649f18cd 100644 --- a/xla/service/spmd/gather_scatter_handler.cc +++ b/xla/service/spmd/gather_scatter_handler.cc @@ -684,11 +684,13 @@ absl::StatusOr PartitionGatherIndexParallelDimensions( hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), indices_parallel_dims); const GroupedSharding operand_grouped = - hlo_sharding_util::GroupShardingOnDims(operand.sharding(), - operand_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operand.sharding(), operand_parallel_dims), + new_indices_grouped); const GroupedSharding output_grouped = - hlo_sharding_util::GroupShardingOnDims(gather_output_sharding, - output_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + gather_output_sharding, output_parallel_dims), + new_indices_grouped); PartitionedHlo per_group_operand = PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups); PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( @@ -1130,11 +1132,13 @@ absl::StatusOr PartitionScatterIndexParallelDimensions( hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), indices_parallel_dims); const GroupedSharding operand_grouped = - hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(), - operand_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operands[0].sharding(), operand_parallel_dims), + new_indices_grouped); const GroupedSharding update_grouped = - hlo_sharding_util::GroupShardingOnDims(updates[0].sharding(), - update_parallel_dims); + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + updates[0].sharding(), update_parallel_dims), + new_indices_grouped); const GroupedSharding& output_grouped = operand_grouped; std::vector per_group_operands = PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);