Skip to content

Commit

Permalink
Add a pass to remove redundant check (select) feeding into a collecti…
Browse files Browse the repository at this point in the history
…ve permute with cycle.

When collective-permute operates on a comparison to a device id and the senders match the condition's branch we can link collective-permute to the original data skipping the comparison.
For example
   condition = broadcast(compare(replica_id, X), direction=EQ
   data_snd = select(condition, compare_true_data, compare_false_data)
   rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}}
 can be transformed to
   rcv = collective-permute(compare_true_data), pairs={{X,0}}

 The pass is *only* handling compare direction={EQ,NE}.
 The pass handles Compare with and without preceding Broadcast.

PiperOrigin-RevId: 674474635
  • Loading branch information
toli-y authored and Google-ML-Automation committed Sep 17, 2024
1 parent bc1aad8 commit 67b9ac3
Show file tree
Hide file tree
Showing 4 changed files with 583 additions and 0 deletions.
34 changes: 34 additions & 0 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,40 @@ xla_cc_test(
],
)

cc_library(
name = "collective_select_folder",
srcs = ["collective_select_folder.cc"],
hdrs = ["collective_select_folder.h"],
deps = [
"//xla:comparison_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "collective_select_folder_test",
srcs = ["collective_select_folder_test.cc"],
deps = [
":collective_select_folder",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "collective_permute_valid_iteration_annotator",
srcs = ["collective_permute_valid_iteration_annotator.cc"],
Expand Down
164 changes: 164 additions & 0 deletions xla/service/gpu/transforms/collective_select_folder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/transforms/collective_select_folder.h"

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/comparison_util.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

using SourceTargetPair = std::pair<int64_t, int64_t>;
using SourceTargetPairs = std::vector<SourceTargetPair>;

struct SelectPredInfo {
int64_t constant;
Comparison::Direction direction;
HloOpcode device_id_type; // kReplicaId or kPartitionId
HloInstruction* true_operand;
HloInstruction* false_operand;
};

// Returns handy references to %constant, %true_operand, %false_operand of the
// select(broadcast(compare(current_device_id, constant)), true_operand,
// false_operand)
// or
// select(compare(current_device_id, constant), true_operand,
// false_operand)
std::optional<SelectPredInfo> GetPredSelectInfo(HloInstruction* select) {
if (select->opcode() != HloOpcode::kSelect) {
return std::nullopt;
}

const HloCompareInstruction* compare;
if (select->operand(0)->opcode() == HloOpcode::kCompare) {
compare = Cast<HloCompareInstruction>(select->operand(0));
} else if (select->operand(0)->opcode() == HloOpcode::kBroadcast &&
select->operand(0)->operand(0)->opcode() == HloOpcode::kCompare) {
compare = Cast<HloCompareInstruction>(select->operand(0)->operand(0));
} else {
return std::nullopt;
}

bool is_replica_or_partition_compare =
(compare->operand(0)->opcode() == HloOpcode::kReplicaId ||
compare->operand(0)->opcode() == HloOpcode::kPartitionId) &&
compare->operand(1)->opcode() == HloOpcode::kConstant;

if (!is_replica_or_partition_compare) return std::nullopt;

int64_t id_value =
compare->operand(1)->literal().GetFirstInteger().value_or(-1);

return SelectPredInfo{id_value, compare->direction(),
compare->operand(0)->opcode(),
select->mutable_operand(1), select->mutable_operand(2)};
}

bool IsUniqueSource(int64_t device_id, const SourceTargetPairs& pairs) {
if (pairs.size() == 1 && pairs[0].first == device_id) return true;
return false;
}

bool IsNotPresentInSource(int64_t device_id, const SourceTargetPairs& pairs) {
for (const auto& pair : pairs) {
if (pair.first == device_id) return false;
}
return true;
}

inline absl::StatusOr<bool> update(HloInstruction* cp, HloInstruction* data) {
TF_RETURN_IF_ERROR(cp->ReplaceOperandWith(0, data));
return true;
}

// We have to maintain integrity of relationship between partition/replica
// and collective-permute's channel_id.
// That is we can only fold select when
// 1. cp has channel_id and condition is based on partition_id
// 2. cp has no channel_id and condition is based on replica_id
// see enum class CollectiveOpGroupMode for details.
bool IsShardingConsistent(HloCollectivePermuteInstruction* cp,
HloOpcode device_id_type) {
auto id = cp->channel_id();
return (device_id_type == HloOpcode::kPartitionId && id.has_value()) ||
(device_id_type == HloOpcode::kReplicaId && !id.has_value());
}

// Recognizer the pattern and update if applicable.
absl::StatusOr<bool> TryFoldSelect(HloInstruction* in) {
if (in->opcode() != HloOpcode::kCollectivePermute) return false;
auto select_info_opt = GetPredSelectInfo(in->mutable_operand(0));
if (!select_info_opt.has_value()) return false;
auto select_info = select_info_opt.value();

HloCollectivePermuteInstruction* cp =
Cast<HloCollectivePermuteInstruction>(in);
if (!IsShardingConsistent(cp, select_info.device_id_type)) return false;

int64_t device_id = select_info.constant;
SourceTargetPairs pairs = cp->source_target_pairs();

if (select_info.direction == Comparison::Direction::kEq) {
if (IsUniqueSource(device_id, pairs)) {
return update(cp, select_info.true_operand);
} else if (IsNotPresentInSource(device_id, pairs)) {
return update(cp, select_info.false_operand);
}
}

if (select_info.direction == Comparison::Direction::kNe) {
if (IsNotPresentInSource(device_id, pairs)) {
return update(cp, select_info.true_operand);
} else if (IsUniqueSource(device_id, pairs)) {
return update(cp, select_info.false_operand);
}
}
return false;
}

} // namespace

absl::StatusOr<bool> CollectiveSelectFolder::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
bool changed = false;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
TF_ASSIGN_OR_RETURN(bool local_changed, TryFoldSelect(instruction));
changed |= local_changed;
}
}
return changed;
}

} // namespace xla
69 changes: 69 additions & 0 deletions xla/service/gpu/transforms/collective_select_folder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

namespace xla {

// When collective-permute operates on a comparison to a device id
// and the senders match the condition's branch
// we can link collective-permute to the original data skipping the comparison.
// For example
// condition = broadcast(compare(replica_id, X), direction=EQ
// data_snd = select(condition, compare_true_data, compare_false_data)
// rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}}
// can be transformed to
// rcv = collective-permute(compare_true_data), pairs={{X,0}}
//
// The pass is *only* handling compare direction={EQ,NE}.
// The pass handles Compare with and without preceding Broadcast.
//
// This pass is particularly useful in the pipeline parallelism generated module
// such as:
// fwd_data = ...
// bwd_data =
// is_first_device = ...
// is_last_device = ...
// data_snd = select(is_last_device, bwd_data, fwd_data)
// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}}
// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}}
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv)
//
// After the transformation, the module will become:
// fwd_data_snd = ...
// bwd_data_snd = ...
// is_first_device = ...
// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}}
// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}}
// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv)
class CollectiveSelectFolder : public HloModulePass {
public:
absl::string_view name() const override { return "collective-select-folder"; }

absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

} // namespace xla

#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SELECT_FOLDER_H_
Loading

0 comments on commit 67b9ac3

Please sign in to comment.