diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc
index 5d1002168daf2..e6248154036cb 100644
--- a/xla/hlo/utils/hlo_sharding_util.cc
+++ b/xla/hlo/utils/hlo_sharding_util.cc
@@ -1078,6 +1078,80 @@ bool ContainsTileSharding(const HloModule& module) {
   return false;
 }
 
+template <typename T>
+std::vector<int64_t> argsort(absl::Span<const T> data) {
+  std::vector<int64_t> indices(data.size());
+  std::iota(indices.begin(), indices.end(), 0);
+  std::sort(indices.begin(), indices.end(),
+            [&data](int64_t i, int64_t j) { return data[i] < data[j]; });
+  return indices;
+}
+
+// Given a `source_sharding`, preserve the tiles along the `source_dims` and
+// replicate the rest. The `target_dims` are used to determine the order of the
+// dimensions in the resulting sharding. If `source_dims` and `target_dims` are
+// in a different order (i.e., have different argsort results), we need to
+// transpose the tile assignment.
+//
+// Given the following input,
+// * source_sharding = {devices=[2,3,5,7,11]<=[2310]}
+// * source_dims = [2, 4, 1]
+// * target_dims = [2, 1, 3]
+// * target_shape_rank = 5
+// The result should be {devices=[1,11,5,3,1,14]<=[2,3,5,7,11]T(4,2,1,0,3)
+// last_tile_dim_replicate}.
+HloSharding PropagateShardingAlongDimsAndReplicateOthers(
+    const HloSharding& source_sharding, absl::Span<const int64_t> source_dims,
+    absl::Span<const int64_t> target_dims, int64_t target_shape_rank) {
+  CHECK_EQ(source_dims.size(), target_dims.size());
+  if (source_sharding.IsTileMaximal() || source_sharding.IsManual()) {
+    return source_sharding;
+  }
+
+  HloSharding replicate_other_dims =
+      PartiallyReplicateTiledShardingOnAllDimsExcept(source_sharding,
+                                                     source_dims);
+  if (replicate_other_dims.IsTileMaximal()) {
+    return replicate_other_dims;
+  }
+
+  std::vector<int64_t> argsort_source_dims = argsort(source_dims);
+  std::vector<int64_t> argsort_target_dims = argsort(target_dims);
+  if (argsort_source_dims != argsort_target_dims) {
+    std::vector<int64_t> perm(
+        replicate_other_dims.tile_assignment().num_dimensions(), -1);
+    for (int64_t i = 0; i < source_dims.size(); ++i) {
+      perm[source_dims[argsort_target_dims[i]]] = i;
+    }
+    int64_t i = source_dims.size();
+    for (int64_t& perm_element : perm) {
+      if (perm_element == -1) {
+        perm_element = i++;
+      }
+    }
+    replicate_other_dims = TransposeSharding(replicate_other_dims, perm);
+  }
+
+  std::vector<int64_t> target_tile_dims(target_shape_rank, 1);
+  for (int i = 0; i < source_dims.size(); ++i) {
+    target_tile_dims[target_dims[i]] =
+        source_sharding.tile_assignment().dim(source_dims[i]);
+  }
+  for (int64_t i = replicate_other_dims.TiledDataRank();
+       i < replicate_other_dims.tile_assignment().num_dimensions(); ++i) {
+    target_tile_dims.push_back(replicate_other_dims.tile_assignment().dim(i));
+  }
+
+  auto target_tile_assignment =
+      replicate_other_dims.tile_assignment().Reshape(target_tile_dims);
+  return replicate_other_dims.ReplicateOnLastTileDim()
+             ? HloSharding::PartialTile(target_tile_assignment,
+                                        replicate_other_dims.metadata())
+             : HloSharding::Subgroup(target_tile_assignment,
+                                     replicate_other_dims.subgroup_types(),
+                                     replicate_other_dims.metadata());
+}
+
 HloSharding GatherOutputShardingFromIndexIndexPassthroughDimensions(
     const HloSharding& index_sharding, const HloInstruction* hlo) {
   CHECK(hlo->opcode() == HloOpcode::kGather);
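As a standalone illustration of the transpose-then-reshape logic above (plain C++, with argsort inlined; all values come from the worked example in the function comment, and the subgroup size 14 = 2 * 7 is the product of the replicated dims 0 and 3), this sketch reproduces both the permutation and the final tile dimensions:

    // Sketch only: mirrors the perm/reshape steps of
    // PropagateShardingAlongDimsAndReplicateOthers for the documented example.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <numeric>
    #include <vector>

    std::vector<int64_t> argsort(const std::vector<int64_t>& data) {
      std::vector<int64_t> indices(data.size());
      std::iota(indices.begin(), indices.end(), 0);
      std::sort(indices.begin(), indices.end(),
                [&data](int64_t i, int64_t j) { return data[i] < data[j]; });
      return indices;
    }

    int main() {
      const std::vector<int64_t> source_tile_dims = {2, 3, 5, 7, 11};
      const std::vector<int64_t> source_dims = {2, 4, 1};
      const std::vector<int64_t> target_dims = {2, 1, 3};
      const int64_t target_shape_rank = 5;
      // After partially replicating dims 0 and 3, the tile assignment has six
      // dimensions: [1, 3, 5, 1, 11, 14], with a subgroup of 2 * 7 = 14.
      const int64_t num_tile_dims = 6;

      // argsort(source_dims) = [2, 0, 1] != argsort(target_dims) = [1, 0, 2],
      // so a transpose is needed; this reproduces perm = [3, 2, 1, 4, 0, 5].
      const std::vector<int64_t> argsort_target_dims = argsort(target_dims);
      std::vector<int64_t> perm(num_tile_dims, -1);
      for (int64_t i = 0; i < static_cast<int64_t>(source_dims.size()); ++i) {
        perm[source_dims[argsort_target_dims[i]]] = i;
      }
      int64_t next = static_cast<int64_t>(source_dims.size());
      for (int64_t& p : perm) {
        if (p == -1) p = next++;
      }

      // Reshape step: each target dim receives the tile count of its paired
      // source dim, then the trailing subgroup dim (14) is appended. This
      // reproduces the documented result [1, 11, 5, 3, 1, 14].
      std::vector<int64_t> target_tile_dims(target_shape_rank, 1);
      for (size_t i = 0; i < source_dims.size(); ++i) {
        target_tile_dims[target_dims[i]] = source_tile_dims[source_dims[i]];
      }
      target_tile_dims.push_back(14);

      for (int64_t p : perm) std::cout << p << ' ';  // prints: 3 2 1 4 0 5
      std::cout << '\n';
      for (int64_t d : target_tile_dims) std::cout << d << ' ';  // 1 11 5 3 1 14
      std::cout << '\n';
    }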
@@ -1559,71 +1633,37 @@ std::optional<HloSharding> GatherOperandShardingFromOutputParallelDimensions(
   if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) {
     return output_sharding;
   }
-  auto parallel_dims = GetGatherParallelBatchDims(gather, call_graph);
-  if (parallel_dims) {
-    auto output_parallel_dims =
-        GetGatherParallelOutputDims(gather, *parallel_dims);
-    auto output_aligned_operand_parallel_dims =
-        parallel_dims->operand_parallel_dims;
-    const Shape gather_shape = gather.shape();
-    CHECK_EQ(output_parallel_dims.size(),
-             output_aligned_operand_parallel_dims.size());
-    DimensionVector operand_tile_assignment(gather.operand(0)->shape().rank(),
-                                            1);
-    DimensionVector relevant_output_dims;
-    for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
-      if (parallel_idx >= output_parallel_dims.size() ||
-          output_parallel_dims[parallel_idx] != i) {
-        continue;
-      }
-      const int64_t operand_dim =
-          output_aligned_operand_parallel_dims[parallel_idx++];
-      operand_tile_assignment[operand_dim] =
-          output_sharding.tile_assignment().dim(i);
-      relevant_output_dims.push_back(i);
-    }
-    HloSharding relevant_output_sharding =
-        PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
-                                                       relevant_output_dims);
-    if (relevant_output_sharding.IsTileMaximal()) {
-      return std::move(relevant_output_sharding);
-    }
-
-    for (int64_t i = relevant_output_sharding.TiledDataRank();
-         i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
-      operand_tile_assignment.push_back(
-          relevant_output_sharding.tile_assignment().dim(i));
-    }
-    auto tile_assignment = relevant_output_sharding.tile_assignment().Reshape(
-        operand_tile_assignment);
-    return relevant_output_sharding.ReplicateOnLastTileDim()
-               ? HloSharding::PartialTile(tile_assignment,
-                                          output_sharding.metadata())
-               : HloSharding::Subgroup(
-                     tile_assignment, relevant_output_sharding.subgroup_types(),
-                     output_sharding.metadata());
+
+  GatherScatterParallelDims parallel_dims;
+
+  const GatherDimensionNumbers& dnums = gather.gather_dimension_numbers();
+  if (!dnums.operand_batching_dims().empty()) {
+    parallel_dims.operand_parallel_dims.assign(
+        dnums.operand_batching_dims().begin(),
+        dnums.operand_batching_dims().end());
+    parallel_dims.indices_parallel_dims.assign(
+        dnums.start_indices_batching_dims().begin(),
+        dnums.start_indices_batching_dims().end());
+  }
+  if (std::optional<GatherScatterParallelDims> implicit_parallel_dims =
+          GetGatherParallelBatchDims(gather, call_graph)) {
+    parallel_dims.operand_parallel_dims.insert(
+        parallel_dims.operand_parallel_dims.end(),
+        implicit_parallel_dims->operand_parallel_dims.begin(),
+        implicit_parallel_dims->operand_parallel_dims.end());
+    parallel_dims.indices_parallel_dims.insert(
+        parallel_dims.indices_parallel_dims.end(),
+        implicit_parallel_dims->indices_parallel_dims.begin(),
+        implicit_parallel_dims->indices_parallel_dims.end());
   }
-  return std::nullopt;
-}
 
-// Reorders `to_align` based on the order of how `target_permuted` is reordered
-// from `target`, expecting the container size to be small.
-absl::InlinedVector<int64_t, 1> AlignSmallContainers(
-    absl::Span<const int64_t> to_align, absl::Span<const int64_t> target,
-    absl::Span<const int64_t> target_permuted) {
-  CHECK(absl::c_is_permutation(target_permuted, target));
-  CHECK_EQ(to_align.size(), target.size());
-  absl::InlinedVector<int64_t, 1> to_align_permuted(to_align.size());
-  for (auto i = 0; i < target.size(); ++i) {
-    // This is small so just look linearly.
-    for (auto j = 0; j < target_permuted.size(); ++j) {
-      if (target_permuted[j] == target[i]) {
-        to_align_permuted[j] = to_align[i];
-        break;
-      }
-    }
+  if (parallel_dims.operand_parallel_dims.empty()) {
+    return std::nullopt;
   }
-  return to_align_permuted;
+
+  return PropagateShardingAlongDimsAndReplicateOthers(
+      output_sharding, GetGatherParallelOutputDims(gather, parallel_dims),
+      parallel_dims.operand_parallel_dims, gather.operand(0)->shape().rank());
 }
 
 }  // namespace
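The rewritten gather path above seeds the parallel dims from the explicit batch dims first and then appends any implicit parallel dims discovered from the index computation; the scatter path in the next hunk mirrors it. As a minimal sketch of just that merge, using a simplified stand-in for XLA's GatherScatterParallelDims and the dims of the GatherBackwardWithExplicitBatchDims test added below (which has no implicit parallel dims):

    // Sketch only: simplified stand-in for the merge performed above.
    #include <cstdint>
    #include <optional>
    #include <vector>

    struct ParallelDims {  // stand-in for XLA's GatherScatterParallelDims
      std::vector<int64_t> operand_parallel_dims;
      std::vector<int64_t> indices_parallel_dims;
    };

    // Explicit batch dims from the dimension numbers are seeded first, then
    // implicit parallel dims (e.g. derived from iota indices) are appended,
    // preserving the pairing order between operand dims and indices dims.
    ParallelDims MergeParallelDims(const ParallelDims& explicit_batch_dims,
                                   const std::optional<ParallelDims>& implicit) {
      ParallelDims merged = explicit_batch_dims;
      if (implicit.has_value()) {
        merged.operand_parallel_dims.insert(
            merged.operand_parallel_dims.end(),
            implicit->operand_parallel_dims.begin(),
            implicit->operand_parallel_dims.end());
        merged.indices_parallel_dims.insert(
            merged.indices_parallel_dims.end(),
            implicit->indices_parallel_dims.begin(),
            implicit->indices_parallel_dims.end());
      }
      return merged;
    }

    int main() {
      // operand_batching_dims={0,2}, start_indices_batching_dims={1,0},
      // and no implicit parallel dims, as in the new gather test.
      const ParallelDims merged = MergeParallelDims({{0, 2}, {1, 0}}, std::nullopt);
      // If merged.operand_parallel_dims were empty, the rewritten functions
      // would bail out with std::nullopt.
      return merged.operand_parallel_dims.empty() ? 1 : 0;
    }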
@@ -1776,58 +1816,37 @@ std::optional<HloSharding> ScatterUpdateShardingFromOutputParallelDimensions(
   if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) {
     return output_sharding;
   }
-  auto parallel_dims = GetScatterParallelBatchDims(scatter, call_graph);
-  if (parallel_dims) {
-    auto update_parallel_dims =
-        GetScatterParallelUpdateDims(scatter, *parallel_dims);
-    auto index_aligned_operand_parallel_dims =
-        parallel_dims->operand_parallel_dims;
-    auto operand_parallel_dims_sorted = index_aligned_operand_parallel_dims;
-    absl::c_sort(operand_parallel_dims_sorted);
-    auto operand_aligned_update_parallel_dims = AlignSmallContainers(
-        update_parallel_dims, index_aligned_operand_parallel_dims,
-        operand_parallel_dims_sorted);
-    const Shape scatter_shape = scatter.shape().IsTuple()
-                                    ? scatter.shape().tuple_shapes()[0]
-                                    : scatter.shape();
-    CHECK_EQ(update_parallel_dims.size(),
-             index_aligned_operand_parallel_dims.size());
-    DimensionVector update_tile_assignment(
-        scatter.scatter_updates()[0]->shape().rank(), 1);
-    DimensionVector relevant_output_dims;
-    for (int i = 0, parallel_idx = 0; i < scatter_shape.rank(); ++i) {
-      if (parallel_idx >= operand_parallel_dims_sorted.size() ||
-          operand_parallel_dims_sorted[parallel_idx] != i) {
-        continue;
-      }
-      const int64_t update_dim =
-          operand_aligned_update_parallel_dims[parallel_idx++];
-      update_tile_assignment[update_dim] =
-          output_sharding.tile_assignment().dim(i);
-      relevant_output_dims.push_back(i);
-    }
-    HloSharding relevant_output_sharding =
-        PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
-                                                       relevant_output_dims);
-    if (relevant_output_sharding.IsTileMaximal()) {
-      return std::move(relevant_output_sharding);
-    }
-
-    for (int64_t i = relevant_output_sharding.TiledDataRank();
-         i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
-      update_tile_assignment.push_back(
-          relevant_output_sharding.tile_assignment().dim(i));
-    }
-    auto tile_assignment = relevant_output_sharding.tile_assignment().Reshape(
-        update_tile_assignment);
-    return relevant_output_sharding.ReplicateOnLastTileDim()
-               ? HloSharding::PartialTile(tile_assignment,
-                                          output_sharding.metadata())
-               : HloSharding::Subgroup(
-                     tile_assignment, relevant_output_sharding.subgroup_types(),
-                     output_sharding.metadata());
+
+  GatherScatterParallelDims parallel_dims;
+
+  const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers();
+  if (!dnums.input_batching_dims().empty()) {
+    parallel_dims.operand_parallel_dims.assign(
+        dnums.input_batching_dims().begin(), dnums.input_batching_dims().end());
+    parallel_dims.indices_parallel_dims.assign(
+        dnums.scatter_indices_batching_dims().begin(),
+        dnums.scatter_indices_batching_dims().end());
+  }
+  if (std::optional<GatherScatterParallelDims> implicit_parallel_dims =
+          GetScatterParallelBatchDims(scatter, call_graph)) {
+    parallel_dims.operand_parallel_dims.insert(
+        parallel_dims.operand_parallel_dims.end(),
+        implicit_parallel_dims->operand_parallel_dims.begin(),
+        implicit_parallel_dims->operand_parallel_dims.end());
+    parallel_dims.indices_parallel_dims.insert(
+        parallel_dims.indices_parallel_dims.end(),
+        implicit_parallel_dims->indices_parallel_dims.begin(),
+        implicit_parallel_dims->indices_parallel_dims.end());
+  }
+
+  if (parallel_dims.operand_parallel_dims.empty()) {
+    return std::nullopt;
   }
-  return std::nullopt;
+
+  return PropagateShardingAlongDimsAndReplicateOthers(
+      output_sharding, parallel_dims.operand_parallel_dims,
+      GetScatterParallelUpdateDims(scatter, parallel_dims),
+      scatter.scatter_updates()[0]->shape().rank());
 }
 
 HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions(
@@ -2384,6 +2403,7 @@ GetGatherOutputOrScatterUpdateParallelDims(
       ++idx_dim;
     }
   }
+  CHECK_EQ(output_parallel_dims.size(), indices_parallel_dims.size());
   return output_parallel_dims;
 }
 
@@ -2478,82 +2498,14 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims(
   return passthrough_dims;
 }
 
-template <typename T>
-std::vector<int64_t> argsort(absl::Span<const T> data) {
-  std::vector<int64_t> indices(data.size());
-  std::iota(indices.begin(), indices.end(), 0);
-  std::sort(indices.begin(), indices.end(),
-            [&data](int64_t i1, int64_t i2) { return data[i1] < data[i2]; });
-  return indices;
-}
-
 HloSharding InferGatherScatterParallelShardingFromOperandSharding(
     const HloSharding& operand_sharding, const Shape& operand_shape,
     const Shape& shape,
     absl::Span<const int64_t> output_aligned_operand_parallel_dims,
     absl::Span<const int64_t> output_parallel_dims) {
-  if (operand_sharding.IsTileMaximal()) {
-    return operand_sharding;
-  }
-
-  HloSharding replicate_non_parallel_dims =
-      PartiallyReplicateTiledShardingOnAllDimsExcept(
-          operand_sharding, output_aligned_operand_parallel_dims);
-  if (replicate_non_parallel_dims.IsTileMaximal()) {
-    return replicate_non_parallel_dims;
-  }
-
-  // output_aligned_operand_parallel_dims and output_parallel_dims may not be
-  // in the same order. We need to transpose the sharding accordingly. For
-  // example, if output_aligned_operand_parallel_dims = [2, 4, 1] and
-  // output_parallel_dims = [2, 1, 3], the sharding needs to be transposed with
-  // perm = [3, 2, 1, 4, 0] to adjust the order of devices.
-  std::vector<int64_t> argsort_output_aligned_operand_parallel_dims =
-      argsort(output_aligned_operand_parallel_dims);
-  std::vector<int64_t> argsort_output_parallel_dims =
-      argsort(output_parallel_dims);
-  if (argsort_output_aligned_operand_parallel_dims !=
-      argsort_output_parallel_dims) {
-    std::vector<int64_t> perm(
-        replicate_non_parallel_dims.tile_assignment().num_dimensions(), -1);
-    for (int64_t i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) {
-      perm[output_aligned_operand_parallel_dims
-               [argsort_output_parallel_dims[i]]] = i;
-    }
-    int64_t i = output_aligned_operand_parallel_dims.size();
-    for (int64_t& perm_element : perm) {
-      if (perm_element == -1) {
-        perm_element = i++;
-      }
-    }
-    replicate_non_parallel_dims =
-        TransposeSharding(replicate_non_parallel_dims, perm);
-  }
-
-  // Collect tile dimensions in the operand.
-  std::vector<int64_t> output_tile_dims(shape.rank(), 1);
-  for (int i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) {
-    const int64_t operand_idx = output_aligned_operand_parallel_dims[i];
-    const int64_t output_idx = output_parallel_dims[i];
-    output_tile_dims[output_idx] =
-        operand_sharding.tile_assignment().dim(operand_idx);
-  }
-  for (int64_t i = replicate_non_parallel_dims.TiledDataRank();
-       i < replicate_non_parallel_dims.tile_assignment().num_dimensions();
-       ++i) {
-    output_tile_dims.push_back(
-        replicate_non_parallel_dims.tile_assignment().dim(i));
-  }
-
-  auto output_tile_assignment =
-      replicate_non_parallel_dims.tile_assignment().Reshape(output_tile_dims);
-  return replicate_non_parallel_dims.ReplicateOnLastTileDim()
-             ? HloSharding::PartialTile(output_tile_assignment,
-                                        replicate_non_parallel_dims.metadata())
-             : HloSharding::Subgroup(
-                   output_tile_assignment,
-                   replicate_non_parallel_dims.subgroup_types(),
-                   replicate_non_parallel_dims.metadata());
+  return PropagateShardingAlongDimsAndReplicateOthers(
+      operand_sharding, output_aligned_operand_parallel_dims,
+      output_parallel_dims, shape.rank());
 }
 
 std::string GroupedSharding::ToString() const {
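One subtlety the new tests below rely on: the output parallel dims handed to PropagateShardingAlongDimsAndReplicateOthers keep the pairing order of indices_parallel_dims rather than sorted order, which is exactly the mismatch the argsort comparison in the helper detects. The following standalone sketch of that mapping is a simplified, assumed version of GetGatherOutputOrScatterUpdateParallelDims (the diff only shows its tail); the dim values come from the gather test below:

    // Sketch only: map each indices batch dim to its gather output dim while
    // preserving the pairing order of indices_parallel_dims.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int64_t output_rank = 4;                 // f32[14,10,6,4]
      const std::vector<int64_t> offset_dims = {3};  // offset_dims={3}
      const int64_t index_vector_dim = 3;            // index_vector_dim=3
      const std::vector<int64_t> indices_parallel_dims = {1, 0};

      std::vector<int64_t> output_parallel_dims;
      for (int64_t indices_dim : indices_parallel_dims) {
        // Walk the output dims, skipping offset dims; idx_dim tracks which
        // indices dim (skipping index_vector_dim) produced each batch dim.
        for (int64_t i = 0, idx_dim = 0; i < output_rank; ++i) {
          if (std::find(offset_dims.begin(), offset_dims.end(), i) !=
              offset_dims.end()) {
            continue;
          }
          const int64_t index_dim =
              idx_dim < index_vector_dim ? idx_dim : idx_dim + 1;
          if (index_dim == indices_dim) {
            output_parallel_dims.push_back(i);
            break;
          }
          ++idx_dim;
        }
      }
      // Prints 1 0: output dim 1 pairs with operand batching dim 0 and output
      // dim 0 with operand batching dim 2, consistent with the leading axes of
      // the T(1,0,3,2) device order expected for the operand in the test.
      for (int64_t d : output_parallel_dims) std::cout << d << ' ';
      return 0;
    }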
last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromOperandToResult) { const char* const hlo_string = R"( HloModule module @@ -6638,6 +6671,47 @@ ENTRY entry { "last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, ScatterBackwardWithExplicitBatchDims) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1) + %updates = f32[14,10,6,4] parameter(2) + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[2,2,2,2]<=[16]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true, true, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{devices=[2,2,2,2]<=[16]}")); + EXPECT_THAT(module->entry_computation()->parameter_instruction(1), + op::Sharding("{devices=[2,2,1,1,4]<=[2,2,2,2]T(2,0,1,3) " + "last_tile_dim_replicate}")); + EXPECT_THAT(module->entry_computation()->parameter_instruction(2), + op::Sharding("{devices=[2,2,1,2,2]<=[2,2,2,2]T(2,0,3,1) " + "last_tile_dim_replicate}")); +} + TEST_P(ParameterizedMetadataTest, ParallelGatherFromOperandForwardPass) { const char* const hlo_string = R"( HloModule module