-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a pass to remove redundant check (select) feeding into a collecti…
…ve permute with cycle. When collective-permute operates on a comparison to a device id and the senders match the condition's branch we can link collective-permute to the original data skipping the comparison. For example condition = broadcast(compare(replica_id, X), direction=EQ data_snd = select(condition, compare_true_data, compare_false_data) rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}} can be transformed to rcv = collective-permute(compare_true_data), pairs={{X,0}} The pass is *only* handling compare direction={EQ,NE}. The pass handles Compare with and without preceding Broadcast. PiperOrigin-RevId: 676181937
- Loading branch information
1 parent
7640789
commit a050924
Showing
4 changed files
with
586 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "xla/service/gpu/transforms/collective_select_folder.h" | ||
|
||
#include <cstdint> | ||
#include <optional> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "absl/algorithm/container.h" | ||
#include "absl/container/flat_hash_set.h" | ||
#include "absl/log/check.h" | ||
#include "absl/status/statusor.h" | ||
#include "absl/strings/string_view.h" | ||
#include "xla/comparison_util.h" | ||
#include "xla/hlo/ir/hlo_casting_utils.h" | ||
#include "xla/hlo/ir/hlo_computation.h" | ||
#include "xla/hlo/ir/hlo_instruction.h" | ||
#include "xla/hlo/ir/hlo_instructions.h" | ||
#include "xla/hlo/ir/hlo_module.h" | ||
#include "xla/hlo/ir/hlo_opcode.h" | ||
#include "tsl/platform/errors.h" | ||
#include "tsl/platform/statusor.h" | ||
|
||
namespace xla { | ||
namespace { | ||
|
||
using SourceTargetPair = std::pair<int64_t, int64_t>; | ||
using SourceTargetPairs = std::vector<SourceTargetPair>; | ||
|
||
struct SelectPredInfo { | ||
int64_t constant; | ||
Comparison::Direction direction; | ||
HloOpcode device_id_type; // kReplicaId or kPartitionId | ||
HloInstruction* true_operand; | ||
HloInstruction* false_operand; | ||
}; | ||
|
||
// Returns handy references to %constant, %true_operand, %false_operand of the | ||
// select(broadcast(compare(current_device_id, constant)), true_operand, | ||
// false_operand) | ||
// or | ||
// select(compare(current_device_id, constant), true_operand, | ||
// false_operand) | ||
std::optional<SelectPredInfo> GetPredSelectInfo(HloInstruction* select) { | ||
if (select->opcode() != HloOpcode::kSelect) { | ||
return std::nullopt; | ||
} | ||
|
||
// Select may have broadcast. | ||
const HloInstruction* compare_candidate = select->operand(0); | ||
if (compare_candidate->opcode() != HloOpcode::kCompare) { | ||
compare_candidate = compare_candidate->operand(0); | ||
} | ||
if (compare_candidate->opcode() != HloOpcode::kCompare) { | ||
return std::nullopt; | ||
} | ||
|
||
const HloCompareInstruction* compare = | ||
DynCast<HloCompareInstruction>(compare_candidate); | ||
|
||
if ((compare->operand(0)->opcode() != HloOpcode::kReplicaId && | ||
compare->operand(0)->opcode() != HloOpcode::kPartitionId) || | ||
compare->operand(1)->opcode() != HloOpcode::kConstant) { | ||
return std::nullopt; | ||
} | ||
|
||
int64_t id_value = | ||
compare->operand(1)->literal().GetFirstInteger().value_or(-1); | ||
|
||
return SelectPredInfo{id_value, compare->direction(), | ||
compare->operand(0)->opcode(), | ||
select->mutable_operand(1), select->mutable_operand(2)}; | ||
} | ||
|
||
bool IsUniqueSource(int64_t device_id, const SourceTargetPairs& pairs) { | ||
if (pairs.size() == 1 && pairs[0].first == device_id) return true; | ||
return false; | ||
} | ||
|
||
bool IsNotPresentInSource(int64_t device_id, const SourceTargetPairs& pairs) { | ||
return absl::c_none_of( | ||
pairs, [device_id](const auto& pair) { return pair.first == device_id; }); | ||
} | ||
|
||
inline absl::StatusOr<bool> update(HloInstruction* cp, HloInstruction* data) { | ||
TF_RETURN_IF_ERROR(cp->ReplaceOperandWith(0, data)); | ||
return true; | ||
} | ||
|
||
// We have to maintain integrity of relationship between partition/replica | ||
// and collective-permute's channel_id. | ||
// That is we can only fold select when | ||
// 1. cp has channel_id and condition is based on partition_id | ||
// 2. cp has no channel_id and condition is based on replica_id | ||
// See enum class CollectiveOpGroupMode for details. | ||
bool IsShardingConsistent(HloCollectivePermuteInstruction* cp, | ||
HloOpcode device_id_type) { | ||
auto id = cp->channel_id(); | ||
return (device_id_type == HloOpcode::kPartitionId && id.has_value()) || | ||
(device_id_type == HloOpcode::kReplicaId && !id.has_value()); | ||
} | ||
|
||
// Recognizes the pattern and update if applicable. | ||
absl::StatusOr<bool> TryFoldSelect(HloInstruction* in) { | ||
if (in->opcode() != HloOpcode::kCollectivePermute) return false; | ||
auto select_info_opt = GetPredSelectInfo(in->mutable_operand(0)); | ||
if (!select_info_opt.has_value()) return false; | ||
auto select_info = select_info_opt.value(); | ||
|
||
HloCollectivePermuteInstruction* cp = | ||
Cast<HloCollectivePermuteInstruction>(in); | ||
if (!IsShardingConsistent(cp, select_info.device_id_type)) return false; | ||
|
||
int64_t device_id = select_info.constant; | ||
SourceTargetPairs pairs = cp->source_target_pairs(); | ||
|
||
if (select_info.direction == Comparison::Direction::kEq) { | ||
if (IsUniqueSource(device_id, pairs)) { | ||
return update(cp, select_info.true_operand); | ||
} else if (IsNotPresentInSource(device_id, pairs)) { | ||
return update(cp, select_info.false_operand); | ||
} | ||
} | ||
|
||
if (select_info.direction == Comparison::Direction::kNe) { | ||
if (IsNotPresentInSource(device_id, pairs)) { | ||
return update(cp, select_info.true_operand); | ||
} else if (IsUniqueSource(device_id, pairs)) { | ||
return update(cp, select_info.false_operand); | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
} // namespace | ||
|
||
absl::StatusOr<bool> CollectiveSelectFolder::Run( | ||
HloModule* module, | ||
const absl::flat_hash_set<absl::string_view>& execution_threads) { | ||
bool changed = false; | ||
for (HloComputation* computation : module->computations()) { | ||
for (HloInstruction* instruction : computation->instructions()) { | ||
TF_ASSIGN_OR_RETURN(bool local_changed, TryFoldSelect(instruction)); | ||
changed |= local_changed; | ||
} | ||
} | ||
return changed; | ||
} | ||
|
||
} // namespace xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* Copyright 2024 The OpenXLA Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_ | ||
#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_ | ||
|
||
#include "absl/container/flat_hash_set.h" | ||
#include "absl/status/statusor.h" | ||
#include "absl/strings/string_view.h" | ||
#include "xla/hlo/ir/hlo_module.h" | ||
#include "xla/hlo/pass/hlo_pass_interface.h" | ||
|
||
namespace xla { | ||
|
||
// When collective-permute operates on a comparison to a device id | ||
// and the senders match the condition's branch | ||
// we can link collective-permute to the original data skipping the comparison. | ||
// For example | ||
// condition = broadcast(compare(replica_id, X), direction=EQ | ||
// data_snd = select(condition, compare_true_data, compare_false_data) | ||
// rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}} | ||
// can be transformed to | ||
// rcv = collective-permute(compare_true_data), pairs={{X,0}} | ||
// | ||
// The pass is *only* handling compare direction={EQ,NE}. | ||
// The pass handles Compare with and without preceding Broadcast. | ||
// | ||
// This pass is particularly useful in the pipeline parallelism generated module | ||
// such as: | ||
// fwd_data = ... | ||
// bwd_data = | ||
// is_first_device = ... | ||
// is_last_device = ... | ||
// data_snd = select(is_last_device, bwd_data, fwd_data) | ||
// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}} | ||
// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}} | ||
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) | ||
// | ||
// After the transformation, the module will become: | ||
// fwd_data_snd = ... | ||
// bwd_data_snd = ... | ||
// is_first_device = ... | ||
// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}} | ||
// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}} | ||
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) | ||
class CollectiveSelectFolder : public HloModulePass { | ||
public: | ||
absl::string_view name() const override { return "collective-select-folder"; } | ||
|
||
absl::StatusOr<bool> Run( | ||
HloModule* module, | ||
const absl::flat_hash_set<absl::string_view>& execution_threads) override; | ||
}; | ||
|
||
} // namespace xla | ||
|
||
#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_ |
Oops, something went wrong.