Skip to content

Commit

Permalink
PR #15403: Handle multiple users in all-gather dynamic-slice simplifi…
Browse files Browse the repository at this point in the history
…cation. Add AllGatherDynamicSliceSimplifier pass

Imported from GitHub PR #15403

I have found in some models that have poor SPMD partitioning the below pattern.

```
all-gather.1 = all-gather(x)
dot.1 = dot(all-gather.1, y)
dynamic-slice.1 = dynamic-slice(all-gather.1) // can be cancelled
```

In this case, the all-gather has multiple users but the dynamic-slice can be cancelled. This is applicable to all-reduce and reduce-scatter also. My changes now support multiple users, but it also depends how this utility is used by internal TPU compiler and the GPU ReduceScatterCreator pass. My changes assume the cancellation is run like this --

1. Find a dynamic-slice
2. Check if dynamic-slice can be cancelled
3. Delete dynamic-slice but do not delete the collective
4. The collective is deleted by the DCE pass if it has no users

The above workflow then supports removing dynamic-slices even if the collective has multiple users. The above is what we are using in our internal Neuron workflow.
Interested to hear thoughts on this.
Copybara import of the project:

--
f518bd6 by ptoulme-aws <[email protected]>:

Handle multiple users in all-gather dynamic-slice simplification. Add AllGatherDynamicSliceSimplifier pass

Merging this change closes #15403

COPYBARA_INTEGRATE_REVIEW=#15403 from ptoulme-aws:multiple_user_collectives f518bd6
PiperOrigin-RevId: 675370754
  • Loading branch information
ptoulme-aws authored and Google-ML-Automation committed Sep 17, 2024
1 parent 26325f0 commit 8e1e8a9
Show file tree
Hide file tree
Showing 8 changed files with 417 additions and 10 deletions.
25 changes: 19 additions & 6 deletions xla/service/collective_opt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,14 @@ bool AllGatherDynamicSliceCancellation(
const HloAllGatherInstruction* ag, int64_t num_partitions,
int64_t num_replicas, bool allow_multiple_split_dims,
bool allow_intervening_reshape, int64_t min_rank,
HloPredicate match_partition_id, HloPredicate match_replica_id) {
HloPredicate match_partition_id, HloPredicate match_replica_id,
bool allow_intervening_bitcast, bool allow_multiple_users) {
auto spec = MatchWithDynamicSlice(
ag, num_partitions, num_replicas, allow_multiple_split_dims,
allow_intervening_reshape, min_rank, match_partition_id, match_replica_id,
ag->constrain_layout(), ag->use_global_device_ids(),
ag->channel_id() && ag->opcode() == HloOpcode::kAllGather);
ag->channel_id() && ag->opcode() == HloOpcode::kAllGather,
allow_intervening_bitcast, allow_multiple_users);
if (spec.has_value()) {
return true;
}
Expand All @@ -340,7 +342,7 @@ std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
bool allow_intervening_reshape, int64_t min_rank,
HloPredicate match_partition_id, HloPredicate match_replica_id,
bool is_constrain_layout, bool use_global_device_ids, bool is_cross_module,
bool allow_intervening_bitcast) {
bool allow_intervening_bitcast, bool allow_multiple_users) {
if (!instruction->shape().IsArray() || is_constrain_layout ||
(is_cross_module &&
!instruction->GetModule()->config().use_spmd_partitioning())) {
Expand All @@ -354,8 +356,8 @@ std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
<< " excluding trivial dimensions " << instruction->ToString();
return std::nullopt;
}
if (instruction->user_count() != 1) {
VLOG(2) << "All-gather user_count > 1 " << instruction->ToString();
if (!allow_multiple_users && instruction->user_count() != 1) {
VLOG(2) << "All-gather user_count != 1 " << instruction->ToString();
return std::nullopt;
}
if (instruction->replica_groups().size() > 1) {
Expand All @@ -371,8 +373,19 @@ std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
return std::nullopt;
}
}

// Always assume first user to start.
HloInstruction* user = instruction->users()[0];
if (allow_multiple_users) {
// If we find a reshape or dynamic-slice use that.
for (auto* some_user : instruction->users()) {
if ((allow_intervening_reshape &&
some_user->opcode() == HloOpcode::kReshape) ||
some_user->opcode() == HloOpcode::kDynamicSlice) {
user = some_user;
break;
}
}
}
HloInstruction* reshape = nullptr;
if (allow_intervening_reshape && user->opcode() == HloOpcode::kReshape) {
// Allow the intervening reshape if it reshapes just the non scattered
Expand Down
10 changes: 6 additions & 4 deletions xla/service/collective_opt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,26 @@ std::optional<ReduceScatterSpec> MatchReduceScatter(
HloPredicate match_replica_id = HloPredicateIsOp<HloOpcode::kReplicaId>,
bool allow_intervening_bitcast = false);

// Check whether AG(ICI) and its single user DS(ICI) can be canceled out.
// Check whether AG(ICI) and its user DS(ICI) can be canceled out.
bool AllGatherDynamicSliceCancellation(
const HloAllGatherInstruction* ag, int64_t num_partitions,
int64_t num_replicas, bool allow_multiple_split_dims = false,
bool allow_intervening_reshape = false, int64_t min_rank = 1,
HloPredicate match_partition_id = HloPredicateIsOp<HloOpcode::kPartitionId>,
HloPredicate match_replica_id = HloPredicateIsOp<HloOpcode::kReplicaId>);
HloPredicate match_replica_id = HloPredicateIsOp<HloOpcode::kReplicaId>,
bool allow_intervening_bitcast = false, bool allow_multiple_users = false);

// Check if a given instruction (AllReduce or AllGather) matches a DynamicSlice;
// the DynamicSlice has to be the only user of the given instruction.
// the DynamicSlice has to be the user of the given instruction.
std::optional<ReduceScatterSpec> MatchWithDynamicSlice(
const HloChannelInstruction* instruction, int64_t num_partitions,
int64_t num_replicas, bool allow_multiple_split_dims = false,
bool allow_intervening_reshape = false, int64_t min_rank = 1,
HloPredicate match_partition_id = HloPredicateIsOp<HloOpcode::kPartitionId>,
HloPredicate match_replica_id = HloPredicateIsOp<HloOpcode::kReplicaId>,
bool is_constrain_layout = false, bool use_global_device_ids = false,
bool is_cross_module = false, bool allow_intervening_bitcast = false);
bool is_cross_module = false, bool allow_intervening_bitcast = false,
bool allow_multiple_users = false);

} // namespace xla

Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,7 @@ cc_library(
"//xla/service/gpu/runtime:thunk",
"//xla/service/gpu/transforms:algebraic_simplifier",
"//xla/service/gpu/transforms:algorithm_checker",
"//xla/service/gpu/transforms:all_gather_dynamic_slice_simplifier",
"//xla/service/gpu/transforms:all_gather_optimizer",
"//xla/service/gpu/transforms:all_reduce_blueconnect",
"//xla/service/gpu/transforms:all_reduce_splitter",
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ limitations under the License.
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/gpu/transforms/algebraic_simplifier.h"
#include "xla/service/gpu/transforms/algorithm_checker.h"
#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h"
#include "xla/service/gpu/transforms/all_gather_optimizer.h"
#include "xla/service/gpu/transforms/all_reduce_blueconnect.h"
#include "xla/service/gpu/transforms/all_reduce_splitter.h"
Expand Down Expand Up @@ -903,6 +904,7 @@ absl::Status RunCollectiveOptimizationPasses(
HloPassPipeline collectives_pipeline("collective-optimizations");
collectives_pipeline.AddPass<AllReduceFolder>();
collectives_pipeline.AddPass<AllReduceSplitter>();
collectives_pipeline.AddPass<AllGatherDynamicSliceSimplifier>();
collectives_pipeline.AddPass<AllGatherOptimizer>();
collectives_pipeline.AddPass<AllReduceReassociate>(
debug_options.xla_gpu_enable_reassociation_for_converted_ar());
Expand Down
25 changes: 25 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,31 @@ xla_cc_test(
],
)

cc_library(
name = "all_gather_dynamic_slice_simplifier",
srcs = ["all_gather_dynamic_slice_simplifier.cc"],
hdrs = ["all_gather_dynamic_slice_simplifier.h"],
deps = [
"//xla/hlo/ir:hlo",
"//xla/service:collective_opt_utils",
"//xla/service:hlo_creation_utils",
"//xla/service:op_expander_pass",
],
)

xla_cc_test(
name = "all_gather_dynamic_slice_simplifier_test",
srcs = ["all_gather_dynamic_slice_simplifier_test.cc"],
deps = [
":all_gather_dynamic_slice_simplifier",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"@tsl//tsl/platform:test_main",
],
)

cc_library(
name = "collective_permute_cycle_decomposer",
srcs = ["collective_permute_cycle_decomposer.cc"],
Expand Down
83 changes: 83 additions & 0 deletions xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h"

#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/service/collective_opt_utils.h"

namespace xla {
bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern(
HloInstruction* instruction) {
if (instruction->opcode() != HloOpcode::kDynamicSlice) {
return false;
}

HloDynamicSliceInstruction* dynamic_slice =
Cast<HloDynamicSliceInstruction>(instruction);
HloInstruction* operand = dynamic_slice->mutable_operand(0);

// Check if the operand is a reshape or all-gather instruction
bool is_reshape = operand->opcode() == HloOpcode::kReshape;
bool is_all_gather = operand->opcode() == HloOpcode::kAllGather;

if (!is_reshape && !is_all_gather) {
return false;
}

if (is_reshape && operand->operand(0)->opcode() != HloOpcode::kAllGather) {
return false;
}

const HloModuleConfig& config = instruction->GetModule()->config();
HloAllGatherInstruction* all_gather =
is_reshape ? Cast<HloAllGatherInstruction>(operand->mutable_operand(0))
: Cast<HloAllGatherInstruction>(operand);

bool match = AllGatherDynamicSliceCancellation(
all_gather, config.num_partitions(), config.replica_count(),
/*allow_multiple_split_dims=*/true,
/*allow_intervening_reshape=*/true, /*min_rank=*/1,
HloPredicateIsOp<HloOpcode::kPartitionId>,
HloPredicateIsOp<HloOpcode::kReplicaId>,
/*allow_intervening_bitcast=*/false,
/*allow_multiple_users=*/true);

return match;
}

StatusOr<HloInstruction*> AllGatherDynamicSliceSimplifier::ExpandInstruction(
HloInstruction* instruction) {
HloDynamicSliceInstruction* dynamic_slice =
Cast<HloDynamicSliceInstruction>(instruction);
HloInstruction* operand = dynamic_slice->mutable_operand(0);

if (operand->opcode() != HloOpcode::kReshape) {
// dynamic-slice(all-gather) case
return operand->mutable_operand(0);
}

// dynamic-slice(reshape(all-gather)) case
HloReshapeInstruction* reshape = Cast<HloReshapeInstruction>(operand);
HloAllGatherInstruction* all_gather =
Cast<HloAllGatherInstruction>(reshape->mutable_operand(0));
HloInstruction* all_gather_input = all_gather->mutable_operand(0);

auto* new_reshape = instruction->parent()->AddInstruction(
HloInstruction::CreateReshape(dynamic_slice->shape(), all_gather_input));
return new_reshape;
}

} // namespace xla
48 changes: 48 additions & 0 deletions xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_

#include "xla/service/op_expander_pass.h"

namespace xla {

// A pass that simplifies a dynamic-slice of an all-gather
// whose slice is the same as the original operand of the all-gather.
// As an example:
//
// ag = all-gather(x) replica_groups={{0,1,2,3,4,5,6,7}}
// offset = multiply(partition_id, slice_size)
// ds = dynamic-slice(ag, offset, 0, 0)
//
// Can be simplified to the all-gather operand.

class AllGatherDynamicSliceSimplifier : public OpExpanderPass {
public:
absl::string_view name() const override {
return "all-gather-dynamic-slice-simplifier";
}

protected:
bool InstructionMatchesPattern(HloInstruction* instruction) override;

StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
};

} // namespace xla

#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_
Loading

0 comments on commit 8e1e8a9

Please sign in to comment.