Skip to content

Commit

Permalink
Adjust cloning behavior to work properly for send + send-done pairs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675290538
  • Loading branch information
pschuh authored and Google-ML-Automation committed Sep 17, 2024
1 parent 8ace4ee commit c3f7f9b
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions xla/python/custom_partition_callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ absl::StatusOr<HloInstruction*> InlineHloComputation(
return it->second;
};

absl::flat_hash_map<int64_t, int64_t> channel_ids;
for (auto* inst : computation->MakeInstructionPostOrder()) {
if (inst->opcode() == HloOpcode::kParameter) {
replacements.emplace(inst, operands[inst->parameter_number()]);
Expand All @@ -80,9 +81,13 @@ absl::StatusOr<HloInstruction*> InlineHloComputation(
auto* new_inst = builder->AddInstruction(
inst->CloneWithNewOperands(inst->shape(), new_operands, &context));
HloChannelInstruction* channel_instr =
DynCast<HloChannelInstruction>(new_inst);
DynCast<HloChannelInstruction>(inst);
if (channel_instr && channel_instr->channel_id().has_value()) {
new_inst->set_channel_id(new_channel());
auto insert = channel_ids.emplace(*channel_instr->channel_id(), 0);
if (insert.second) {
insert.first->second = new_channel();
}
new_inst->set_channel_id(insert.first->second);
}
replacements.emplace(inst, new_inst);
}
Expand Down

0 comments on commit c3f7f9b

Please sign in to comment.