Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a pass to remove redundant check (select) feeding into a collective permute with cycle. #17181

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,41 @@ xla_cc_test(
],
)

cc_library(
name = "collective_select_folder",
srcs = ["collective_select_folder.cc"],
hdrs = ["collective_select_folder.h"],
deps = [
"//xla:comparison_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "collective_select_folder_test",
srcs = ["collective_select_folder_test.cc"],
deps = [
":collective_select_folder",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "collective_permute_valid_iteration_annotator",
srcs = ["collective_permute_valid_iteration_annotator.cc"],
Expand Down
164 changes: 164 additions & 0 deletions xla/service/gpu/transforms/collective_select_folder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/transforms/collective_select_folder.h"

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

using SourceTargetPair = std::pair<int64_t, int64_t>;
using SourceTargetPairs = std::vector<SourceTargetPair>;

struct SelectPredInfo {
int64_t constant;
Comparison::Direction direction;
HloOpcode device_id_type; // kReplicaId or kPartitionId
HloInstruction* true_operand;
HloInstruction* false_operand;
};

// Returns handy references to %constant, %true_operand, %false_operand of the
// select(broadcast(compare(current_device_id, constant)), true_operand,
// false_operand)
// or
// select(compare(current_device_id, constant), true_operand,
// false_operand)
std::optional<SelectPredInfo> GetPredSelectInfo(HloInstruction* select) {
if (select->opcode() != HloOpcode::kSelect) {
return std::nullopt;
}

// Select may have broadcast.
const HloInstruction* compare_candidate = select->operand(0);
if (compare_candidate->opcode() != HloOpcode::kCompare) {
compare_candidate = compare_candidate->operand(0);
}
if (compare_candidate->opcode() != HloOpcode::kCompare) {
return std::nullopt;
}

const HloCompareInstruction* compare =
DynCast<HloCompareInstruction>(compare_candidate);

if ((compare->operand(0)->opcode() != HloOpcode::kReplicaId &&
compare->operand(0)->opcode() != HloOpcode::kPartitionId) ||
compare->operand(1)->opcode() != HloOpcode::kConstant) {
return std::nullopt;
}

int64_t id_value =
compare->operand(1)->literal().GetFirstInteger().value_or(-1);

return SelectPredInfo{id_value, compare->direction(),
compare->operand(0)->opcode(),
select->mutable_operand(1), select->mutable_operand(2)};
}

bool IsUniqueSource(int64_t device_id, const SourceTargetPairs& pairs) {
if (pairs.size() == 1 && pairs[0].first == device_id) return true;
return false;
}

bool IsNotPresentInSource(int64_t device_id, const SourceTargetPairs& pairs) {
return absl::c_none_of(
pairs, [device_id](const auto& pair) { return pair.first == device_id; });
}

inline absl::StatusOr<bool> update(HloInstruction* cp, HloInstruction* data) {
TF_RETURN_IF_ERROR(cp->ReplaceOperandWith(0, data));
return true;
}

// We have to maintain integrity of relationship between partition/replica
// and collective-permute's channel_id.
// That is we can only fold select when
// 1. cp has channel_id and condition is based on partition_id
// 2. cp has no channel_id and condition is based on replica_id
// See enum class CollectiveOpGroupMode for details.
bool IsShardingConsistent(HloCollectivePermuteInstruction* cp,
HloOpcode device_id_type) {
auto id = cp->channel_id();
return (device_id_type == HloOpcode::kPartitionId && id.has_value()) ||
(device_id_type == HloOpcode::kReplicaId && !id.has_value());
}

// Recognizes the pattern and update if applicable.
absl::StatusOr<bool> TryFoldSelect(HloInstruction* in) {
if (in->opcode() != HloOpcode::kCollectivePermute) return false;
auto select_info_opt = GetPredSelectInfo(in->mutable_operand(0));
if (!select_info_opt.has_value()) return false;
auto select_info = select_info_opt.value();

HloCollectivePermuteInstruction* cp =
Cast<HloCollectivePermuteInstruction>(in);
if (!IsShardingConsistent(cp, select_info.device_id_type)) return false;

int64_t device_id = select_info.constant;
SourceTargetPairs pairs = cp->source_target_pairs();

if (select_info.direction == Comparison::Direction::kEq) {
if (IsUniqueSource(device_id, pairs)) {
return update(cp, select_info.true_operand);
} else if (IsNotPresentInSource(device_id, pairs)) {
return update(cp, select_info.false_operand);
}
}

if (select_info.direction == Comparison::Direction::kNe) {
if (IsNotPresentInSource(device_id, pairs)) {
return update(cp, select_info.true_operand);
} else if (IsUniqueSource(device_id, pairs)) {
return update(cp, select_info.false_operand);
}
}
return false;
}

} // namespace

absl::StatusOr<bool> CollectiveSelectFolder::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
bool changed = false;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
TF_ASSIGN_OR_RETURN(bool local_changed, TryFoldSelect(instruction));
changed |= local_changed;
}
}
return changed;
}

} // namespace xla
69 changes: 69 additions & 0 deletions xla/service/gpu/transforms/collective_select_folder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

namespace xla {

// When collective-permute operates on a comparison to a device id
// and the senders match the condition's branch
// we can link collective-permute to the original data skipping the comparison.
// For example
// condition = broadcast(compare(replica_id, X), direction=EQ
// data_snd = select(condition, compare_true_data, compare_false_data)
// rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}}
// can be transformed to
// rcv = collective-permute(compare_true_data), pairs={{X,0}}
//
// The pass is *only* handling compare direction={EQ,NE}.
// The pass handles Compare with and without preceding Broadcast.
//
// This pass is particularly useful in the pipeline parallelism generated module
// such as:
// fwd_data = ...
// bwd_data =
// is_first_device = ...
// is_last_device = ...
// data_snd = select(is_last_device, bwd_data, fwd_data)
// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}}
// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}}
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv)
//
// After the transformation, the module will become:
// fwd_data_snd = ...
// bwd_data_snd = ...
// is_first_device = ...
// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}}
// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}}
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv)
class CollectiveSelectFolder : public HloModulePass {
public:
absl::string_view name() const override { return "collective-select-folder"; }

absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

} // namespace xla

#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_
Loading
Loading