Fix common ancestor in reduce-scatter & Fix replicated iota (#134)
merrymercy committed Sep 25, 2022
1 parent 0823879 commit 2aa7486
Showing 3 changed files with 46 additions and 27 deletions.
9 changes: 6 additions & 3 deletions tensorflow/compiler/xla/service/spmd/auto_sharding.cc
@@ -1297,9 +1297,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
              strategy_map, strategies, false, " 1d");
         }
 
-        // Replicate
-        AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies,
-                              replicated_penalty * 5);
+        if (strategies->leaf_vector.empty() || IsFollowedByBroadcast(ins) ||
+            batch_dim_map.count(ins) == 0) {
+          // Replicate
+          AddReplicatedStrategy(ins, cluster_env, strategy_map, strategies,
+                                replicated_penalty * 5);
+        }
 
         RemoveDuplicatedStrategy(strategies);
         break;
61 changes: 37 additions & 24 deletions tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc
@@ -50,6 +50,24 @@ bool IsBatchDimSwitchReshape(const HloInstruction* inst) {
   return true;
 }
 
+// Return whether the instruction is followed by a broadcast.
+bool IsFollowedByBroadcast(const HloInstruction* ins) {
+  const int max_depth = 6;
+  for (int i = 0; i < max_depth; ++i) {
+    if (ins->users().empty()) {
+      return false;
+    }
+    ins = PassThroughCustomCallMarkerUser(ins->users().front(), ins);
+    if (ins->opcode() == HloOpcode::kBroadcast) {
+      return true;
+    } else if (ins->opcode() == HloOpcode::kReshape) {
+      i--;
+    }
+  }
+
+  return false;
+}
+
 // Return whether the instruction is an activation from another pipeline stage.
 bool IsActivationFromAnotherStage(const HloInstruction* ins,
                                   const InstructionBatchDimMap& batch_dim_map) {
@@ -1339,44 +1357,39 @@ void TryReduceWithCommonAncestor(
     absl::flat_hash_set<HloInstruction*>& boundary_set,
     absl::flat_hash_set<HloInstruction*>& consumer_set,
     const AliasMap& alias_map) {
-  if (boundary_set.size() != 2) {
-    return;
-  }
+  absl::flat_hash_map<const HloInstruction*, HloInstruction*> note_to_ancestor;
+  absl::flat_hash_map<const HloInstruction*, absl::flat_hash_set<HloInstruction*>> path;
+  absl::flat_hash_map<HloInstruction*, absl::flat_hash_set<HloInstruction*>> ancestor_to_node;
 
-  HloInstruction* ancestor = nullptr;
-  absl::flat_hash_set<HloInstruction*> path;
   for (HloInstruction* node : boundary_set) {
     HloInstruction* cur = node;
     while (cur->operand_count() == 1) {
       HloInstruction* operand =
           PassThroughCustomCallMarkerOperand(cur->mutable_operand(0), cur);
       if (replicated_set.count(operand)) {
-        path.insert(cur);
+        path[node].insert(cur);
       }
       cur = operand;
     }
-
-    if (ancestor == nullptr) {
-      ancestor = cur;
-    } else {
-      if (ancestor != cur) {
-        // The nodes in boundary set do not have a common ancestor.
-        // This reduction fails.
-        return;
-      }
-    }
+    note_to_ancestor[node] = cur;
+    ancestor_to_node[cur].insert(node);
   }
-  if (ancestor == nullptr) {
-    return;
-  }
 
-  // Find a common ancestor, reduce the boundary set
-  boundary_set.clear();
-  boundary_set.insert(ancestor);
-  for (auto x : path) {
-    replicated_set.erase(x);
-  }
-  consumer_set.insert(ancestor);
+  for (const auto& iter : ancestor_to_node) {
+    if (iter.second.size() > 1) {
+      // Find a common ancestor, reduce the boundary set
+      for (auto x : iter.second) {
+        boundary_set.erase(x);
+        for (auto y : path[x]) {
+          replicated_set.erase(y);
+        }
+      }
+      boundary_set.insert(iter.first);
+      consumer_set.insert(iter.first);
+
+      // Only allow one modification
+      return;
+    }
+  }
 }
 
 void UseAllReduceForGradAcc(
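For readers following the reduce-scatter fix, the new TryReduceWithCommonAncestor logic above can be read as: bucket each boundary node by the ancestor it reaches along single-operand chains, then collapse the first bucket that contains more than one node onto that shared ancestor. Below is a minimal standalone sketch of that grouping idea, assuming toy string names in place of HloInstruction* and plain std containers instead of absl ones; it is an illustration only, not part of the commit.

// Illustrative sketch only -- not part of the commit. It mimics the new
// grouping-by-ancestor flow with strings instead of HloInstruction*.
#include <iostream>
#include <map>
#include <set>
#include <string>

int main() {
  // Pretend each boundary node already walked up its single-operand chain
  // and reached the ancestor recorded here (names are made up).
  std::map<std::string, std::string> node_to_ancestor = {
      {"boundary_a", "iota"}, {"boundary_b", "iota"}, {"boundary_c", "param"}};

  std::set<std::string> boundary_set = {"boundary_a", "boundary_b", "boundary_c"};
  std::set<std::string> consumer_set;

  // Bucket the boundary nodes by ancestor, as the new ancestor_to_node map does.
  std::map<std::string, std::set<std::string>> ancestor_to_node;
  for (const auto& [node, ancestor] : node_to_ancestor) {
    ancestor_to_node[ancestor].insert(node);
  }

  // Collapse the first bucket with more than one member onto its common
  // ancestor; the break mirrors the commit's "Only allow one modification".
  for (const auto& [ancestor, nodes] : ancestor_to_node) {
    if (nodes.size() > 1) {
      for (const auto& node : nodes) boundary_set.erase(node);
      boundary_set.insert(ancestor);
      consumer_set.insert(ancestor);
      break;
    }
  }

  // boundary_a and boundary_b are replaced by their shared ancestor "iota".
  for (const auto& x : boundary_set) std::cout << x << "\n";
}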
3 changes: 3 additions & 0 deletions tensorflow/compiler/xla/service/spmd/auto_sharding_util.h
@@ -252,6 +252,9 @@ inline bool IsPassThroughTuple(const HloInstruction* inst) {
 // of a dot.
 bool IsBatchDimSwitchReshape(const HloInstruction* inst);
 
+// Return whether the instruction is followed by a broadcast.
+bool IsFollowedByBroadcast(const HloInstruction* inst);
+
 // Return whether the instruction is an activation from another pipeline stage.
 bool IsActivationFromAnotherStage(const HloInstruction* inst,
                                   const InstructionBatchDimMap& batch_dim_map);
