
Commit

[WIP] fixed computation graph visit for BuildStrategyAndCost
HeydrichBeillschmidt committed Aug 27, 2022
1 parent 181e8da commit 359f6a8
Showing 2 changed files with 531 additions and 589 deletions.
61 changes: 53 additions & 8 deletions tensorflow/compiler/xla/service/spmd/auto_sharding.cc
@@ -1274,8 +1274,11 @@ Status BuildStrategyAndCostForInstruction(
 
         int operand_dim = dnums.offset_dims(i);
 
+        CHECK_LT(operand_dim, ins->operand(0)->shape().rank())
+            << "Does not support this kind of Gather.";
         CHECK_EQ(ins->shape().dimensions(operand_dim),
-                 ins->operand(0)->shape().dimensions(operand_dim));
+                 ins->operand(0)->shape().dimensions(operand_dim))
+            << "Does not support this kind of Gather.";
 
         std::vector<HloSharding> operand_specs{
             Tile(ins->operand(0)->shape(), {operand_dim}, {j}, device_mesh),
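
Note on this hunk: the new CHECK_LT makes explicit the previously implicit assumption that every offset dimension indexes a real operand dimension, and the CHECK_EQ now carries the same diagnostic message. A minimal sketch of the combined precondition, using the ins and operand_dim names from the hunk (the helper name is hypothetical, for illustration only):

    // Sketch: a Gather offset dim is handled here only if it is a valid
    // operand dimension whose size is preserved in the output.
    bool IsSupportedOffsetDim(const HloInstruction* ins, int64_t operand_dim) {
      const Shape& operand_shape = ins->operand(0)->shape();
      return operand_dim < operand_shape.rank() &&
             ins->shape().dimensions(operand_dim) ==
                 operand_shape.dimensions(operand_dim);
    }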
@@ -1790,6 +1793,8 @@ Status BuildStrategyAndCostForInstruction(
         strategies->childs.push_back(FollowInsStrategyVector(
             src_strategies, operand->shape(), ins,
             /* have_memory_cost= */ false, leaf_strategies));
+        RemoveIndivisibleStrategies(strategies->childs.back(),
+                                    operand->shape());
       }
       break;
     }
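
RemoveIndivisibleStrategies is defined elsewhere in this file and is not shown in the diff. Judging from its name and arguments, a plausible reading is that it prunes sharding strategies whose tile counts do not evenly divide the corresponding dimensions of the tuple element's shape; a hedged sketch of that filter, under that assumption:

    // Assumed behavior (not the file's actual code): keep a sharding only
    // when every tiled dimension divides evenly, so each device receives an
    // equally sized shard.
    bool TilesDivideEvenly(const HloSharding& sharding, const Shape& shape) {
      if (sharding.IsReplicated() || sharding.IsTileMaximal()) return true;
      for (int64_t i = 0; i < shape.rank(); ++i) {
        if (shape.dimensions(i) % sharding.tile_assignment().dim(i) != 0) {
          return false;
        }
      }
      return true;
    }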
@@ -1907,8 +1912,8 @@ Status BuildStrategyAndCostForInstruction(
       std::vector<const StrategyVector*> candidate_strategies;
       std::vector<HloInstruction*> instructions;
       if (ins->branch_count()) {
-        for (size_t bid = 0; bid < ins->branch_computations().size(); ++bid) {
-          HloComputation* branch_computation = ins->branch_computations()[bid];
+        for (size_t bid = 0; bid < ins->branch_count(); ++bid) {
+          HloComputation* branch_computation = ins->branch_computation(bid);
           instructions = branch_computation->MakeInstructionPostOrder();
           for (size_t niid = 0; niid < instructions.size(); ++niid) {
             const HloInstruction* nins = instructions[niid];
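
Both loop forms enumerate the same computations; branch_count() with branch_computation(bid) indexes the branches directly rather than going through the branch_computations() vector. A minimal usage illustration of these kConditional accessors (the logging is only for the example):

    // branch_computation(b) returns the b-th branch computation; for a
    // predicated conditional, index 0 is the true branch and index 1 the
    // false branch.
    for (int64_t b = 0; b < ins->branch_count(); ++b) {
      HloComputation* branch = ins->branch_computation(b);
      VLOG(2) << "conditional branch " << b << ": " << branch->name();
    }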
@@ -2697,6 +2702,47 @@ void DisableIncompatibleMixedMeshShapeAndForceBatchDim(
   }
 }
 
+// Build one instruction sequence out of the root computations of each
+// subgraph of the module, i.e. the scheduled computations that are not
+// called as a while condition/body or as a conditional branch.
+HloInstructionSequence BuildEntrySequence(
+    const absl::flat_hash_map<const HloInstruction*, HloLiveRange::LogicalTime>&
+        instruction_schedule,
+    const absl::flat_hash_map<const HloComputation*, HloLiveRange::TimeBound>&
+        computation_span_times) {
+  // Start with every scheduled computation as a candidate root.
+  absl::flat_hash_set<const HloComputation*> entries;
+  for (const auto& p : computation_span_times) {
+    entries.insert(p.first);
+  }
+  // Erase every computation that is reachable as a while/conditional
+  // callee; whatever survives is the root of its subgraph.
+  for (const auto& p : computation_span_times) {
+    for (const HloInstruction* ins : p.first->instructions()) {
+      if (!instruction_schedule.count(ins)) continue;
+      if (ins->opcode() == HloOpcode::kWhile) {
+        entries.erase(ins->while_condition());
+        entries.erase(ins->while_body());
+      } else if (ins->opcode() == HloOpcode::kConditional) {
+        if (ins->branch_count()) {
+          for (size_t bid = 0; bid < ins->branch_count(); ++bid) {
+            entries.erase(ins->branch_computation(bid));
+          }
+        } else {
+          entries.erase(ins->true_computation());
+          entries.erase(ins->false_computation());
+        }
+      }
+    }
+  }
+
+  // Concatenate the scheduled instructions of the remaining roots in
+  // post order.
+  HloInstructionSequence entry_sequence;
+  for (const HloComputation* c : entries) {
+    for (HloInstruction* ins : c->MakeInstructionPostOrder()) {
+      if (!instruction_schedule.count(ins)) continue;
+      entry_sequence.push_back(ins);
+    }
+  }
+  return entry_sequence;
+}
+
 StatusOr<bool> AutoSharding::Run(HloModule* module) {
   if (!pass_context::GetBool("auto_sharding::enable", true)) {
     return false;
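
A concrete, hypothetical illustration of BuildEntrySequence: if the live range spans three computations, entry, while_cond, and while_body, where entry holds a kWhile calling the other two, all three enter entries; the kWhile visit erases while_cond and while_body, and the returned sequence contains exactly entry's scheduled instructions in post order. A schedule covering several disjoint subgraphs keeps one root per subgraph, which is presumably the "computation graph visit" the commit title refers to.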
Expand Down Expand Up @@ -2798,12 +2844,11 @@ StatusOr<bool> AutoSharding::Run(HloModule* module) {
// ----- Analyze the batch dim -----
const HloInstructionSequence& sequence =
hlo_live_range->flattened_instruction_sequence();
HloInstructionSequence entry_sequence;
for (HloInstruction* instruction : entry_computation->instructions()) {
entry_sequence.push_back(instruction);
}
HloInstructionSequence entry_sequence = BuildEntrySequence(
hlo_live_range->instruction_schedule(),
hlo_live_range->computation_span_times());
InstructionBatchDimMap batch_dim_map;
batch_dim_map = BuildInstructionBatchDimMap(entry_sequence);
batch_dim_map = BuildInstructionBatchDimMap(sequence);
if (solver_option.force_batch_dim_to_mesh_dim >= 0) {
DisableIncompatibleMixedMeshShapeAndForceBatchDim(
batch_dim_map, device_mesh.num_elements(), solver_option);
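
My reading of this hunk: entry_sequence previously contained only the instructions of entry_computation, so the roots of nested computations were never visited; it is now rebuilt by BuildEntrySequence over the whole schedule, and the batch-dimension map is derived from the full flattened sequence instead of entry_sequence.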
