Skip to content

Commit

Permalink
Reland #17228
Browse files Browse the repository at this point in the history
A couple of changes from the original change:
  1. Don't use HloInstruction::operand_index() - this only returns the
     *first* occurence of an instruction in the operand sequence, thus if
     the same instruction is used in place of multiple orepards, we'll miss the
     subsequent ones.
  2. Handle propagating throught root instructions better. We originally
     only fixed up entry computation roots but we need should the same for
     any while/conditional root, otherwise inserting tokens in these types
     of roots is non-trivial. Simplify things by explicitly disjoining these
     instructions from being roots during canonicalization.

Reverts 1162b7e

PiperOrigin-RevId: 677963786
  • Loading branch information
vsytch authored and Google-ML-Automation committed Sep 24, 2024
1 parent 90693a5 commit a141e1d
Show file tree
Hide file tree
Showing 7 changed files with 1,228 additions and 6 deletions.
32 changes: 29 additions & 3 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2759,6 +2759,20 @@ int64_t HloInstruction::operand_index(const HloInstruction* target) const {
LOG(FATAL) << "target was not an operand: " << target->ToString();
}

std::vector<int64_t> HloInstruction::operand_indices(
const HloInstruction* target) const {
std::vector<int64_t> indices;
for (int64_t i = 0; i < operand_count(); ++i) {
if (target == operand(i)) {
indices.push_back(i);
}
}
if (indices.empty()) {
LOG(FATAL) << "target was not an operand: " << target->ToString();
}
return indices;
}

HloInstruction::InstructionVector HloInstruction::unique_operands() const {
InstructionVector unique;
absl::flat_hash_set<const HloInstruction*> seen;
Expand Down Expand Up @@ -3399,18 +3413,30 @@ const PtrVec<HloComputation*>& HloInstruction::branch_computations() const {
return called_computations();
}

int HloInstruction::branch_count() const {
int32_t HloInstruction::branch_count() const {
CHECK(HloOpcode::kConditional == opcode_);
return called_computations().size();
}

HloComputation* HloInstruction::branch_computation(int b) const {
CHECK(HloOpcode::kConditional == opcode_);
HloComputation* HloInstruction::branch_computation(int32_t b) const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
CHECK_GE(b, 0);
CHECK_LT(b, called_computations().size());
return called_computations()[b];
}

int32_t HloInstruction::branch_index(HloComputation* computation) const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
CHECK_NE(computation, nullptr);
for (int32_t idx = 0; idx < branch_count(); idx++) {
if (branch_computation(idx) == computation) {
return idx;
}
}
LOG(FATAL) << absl::StrFormat("Conditional %s does not contain branch %s",
name(), computation->name());
}

void HloInstruction::set_branch_computation(int b,
HloComputation* computation) {
CHECK_EQ(HloOpcode::kConditional, opcode_);
Expand Down
11 changes: 8 additions & 3 deletions xla/hlo/ir/hlo_instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1493,10 +1493,14 @@ class HloInstruction {
// within the operand vector.
InstructionVector unique_operands() const;

// Returns the index of 'target' in the operands sequence.
// Returns the first index of 'target' that occurs in the operands sequence.
// Precondition: target must be an operand (or a fatal error will occur).
int64_t operand_index(const HloInstruction* target) const;

// Returns all indices of 'target' that occur in the operands sequence.
// Precondition: target must be an operand (or a fatal error will occur).
std::vector<int64_t> operand_indices(const HloInstruction* target) const;

// Returns the number of users of this instruction.
int64_t user_count() const { return users_.size(); }

Expand Down Expand Up @@ -1808,8 +1812,9 @@ class HloInstruction {
//
// Precondition: The instruction is a Conditional instruction.
const PtrVec<HloComputation*>& branch_computations() const;
int branch_count() const;
HloComputation* branch_computation(int b) const;
int32_t branch_count() const;
HloComputation* branch_computation(int32_t b) const;
int32_t branch_index(HloComputation* computation) const;
// Sets a branch HloComputation for Conditional.
// The setter should only be called by HloModule or HloComputation methods.
//
Expand Down
35 changes: 35 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8544,4 +8544,39 @@ xla_cc_test(
],
)

cc_library(
name = "infeed_token_propagation",
srcs = ["infeed_token_propagation.cc"],
hdrs = ["infeed_token_propagation.h"],
deps = [
":hlo_dce",
":tuple_simplifier",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "infeed_token_propagation_test",
srcs = ["infeed_token_propagation_test.cc"],
deps = [
":infeed_token_propagation",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_matchers",
"//xla/tests:hlo_test_base",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"])
Loading

0 comments on commit a141e1d

Please sign in to comment.