[XLA:SPMD] Remove LookaheadUserSharding in sharding propagation.
When we infer the dot sharding from its operands, it is possible that both operands can improve the dot sharding. LookaheadUserSharding inspects the dot's users and decides which operand's sharding is preferred. This CL removes it for two reasons.

1. It is unnecessary. If we can predict the sharding from the dot's users, we can simply wait for that sharding to be propagated back from the users. The propagated user sharding can still help us choose between the dot operands.
2. The lookahead sharding may be wrong. LookaheadUserSharding is a heuristic, and we cannot guarantee that the predicted sharding will actually hold at the dot's users. (A sketch of the replacement tie-breaking logic is shown below.)
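
For context, the patched InferDotShardingFromOperands (see the xla/service/sharding_propagation.cc diff below) resolves the case where both operands propose an improved dot sharding without any user lookahead: prefer a candidate that sub-tiles the other, then try to merge the two candidates, pause below aggressiveness 3 so that shardings propagated back from users can break the tie, and only at the highest aggressiveness fall back to the larger operand. The following is a minimal, self-contained C++ sketch of that decision order; the Sharding struct and the IsSubTilingOrEqual/TryMergeShardings/ChooseDotSharding names are simplified, hypothetical stand-ins, not the real xla::HloSharding or hlo_sharding_util APIs.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for xla::HloSharding, used only in this sketch.
struct Sharding {
  std::string description;
};

// Placeholder predicates standing in for hlo_sharding_util helpers; the real
// implementations compare tile assignments against the instruction shape.
bool IsSubTilingOrEqual(const Sharding& maybe_sub, const Sharding& super) {
  return maybe_sub.description == super.description;
}
std::optional<Sharding> TryMergeShardings(const Sharding& a,
                                          const Sharding& b) {
  (void)a;
  (void)b;
  return std::nullopt;  // Pretend the two candidates cannot be merged.
}

// Mirrors the decision order of the patched InferDotShardingFromOperands when
// both operands propose an improved sharding for the dot.
std::optional<Sharding> ChooseDotSharding(const Sharding& from_lhs,
                                          const Sharding& from_rhs,
                                          int64_t aggressiveness,
                                          int64_t lhs_bytes,
                                          int64_t rhs_bytes) {
  // 1. If one candidate sub-tiles (or equals) the other, keep the finer one.
  if (IsSubTilingOrEqual(from_lhs, from_rhs)) return from_lhs;
  if (IsSubTilingOrEqual(from_rhs, from_lhs)) return from_rhs;
  // 2. If the two candidates merge without conflict, use the merged sharding.
  if (std::optional<Sharding> merged = TryMergeShardings(from_lhs, from_rhs)) {
    return merged;
  }
  // 3. Below the highest aggressiveness, pause and let the sharding that is
  //    propagated back from the dot's users break the tie later.
  if (aggressiveness < 3) return std::nullopt;
  // 4. Otherwise prioritize the sharding proposed by the larger operand.
  return lhs_bytes >= rhs_bytes ? from_lhs : from_rhs;
}

int main() {
  Sharding lhs{"{devices=[2,1]0,1}"};
  Sharding rhs{"{devices=[1,2]0,1}"};
  std::optional<Sharding> chosen =
      ChooseDotSharding(lhs, rhs, /*aggressiveness=*/2,
                        /*lhs_bytes=*/4096, /*rhs_bytes=*/1024);
  std::cout << (chosen ? chosen->description : "undecided; wait for users")
            << "\n";
}
```

Running the sketch with aggressiveness 2 prints "undecided; wait for users", mirroring how the pass now defers the decision until either propagation from users or a higher-aggressiveness iteration resolves it.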

Reverts b4ea979

PiperOrigin-RevId: 669066012
ZixuanJiang authored and Google-ML-Automation committed Sep 18, 2024
1 parent b0e33bb commit 95b8e29
Showing 4 changed files with 161 additions and 117 deletions.
@@ -444,13 +444,13 @@ std::optional<HloSharding> HandlerBase::GetShardingFromUser(
CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get()));
if (ins_->opcode() == HloOpcode::kConvolution) {
xla::InferConvolutionShardingFromOperands(
ins_clone.get(), call_graph_, 10,
/* may_combine_partial_sharding */ true, /* is_spmd */ true);
ins_clone.get(), /* aggressiveness */ 10,
/* may_combine_partial_sharding */ true);
} else {
xla::InferDotShardingFromOperands(
ins_clone.get(), call_graph_,
ins_clone.get(),
dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()),
/* may_combine_partial_sharding/ */ true, /* is_spmd */ true);
/* aggressiveness */ 10, /* may_combine_partial_sharding */ true);
}
if (!ins_clone->has_sharding()) {
return std::nullopt;
153 changes: 57 additions & 96 deletions xla/service/sharding_propagation.cc
@@ -16,10 +16,10 @@ limitations under the License.
#include "xla/service/sharding_propagation.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <optional>
@@ -36,17 +36,20 @@ limitations under the License.
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/array.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_domain_metadata.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/ir/hlo_sharding_metadata.h"
#include "xla/hlo/utils/hlo_sharding_util.h"
#include "xla/protobuf_util.h"
#include "xla/service/call_graph.h"
#include "xla/service/dot_as_convolution_util.h"
#include "xla/service/host_memory_offload_annotations.h"
#include "xla/service/spmd/shard_barrier_partitioner.h"
@@ -416,55 +419,6 @@ bool SupportSpatialPartitioning(
}
}

// Helper to lookahead sharding of user of an instruction to be used as guidance
// for ambiguous cases.
std::optional<HloSharding> LookaheadUserSharding(HloInstruction* instr,
bool is_spmd,
const CallGraph& call_graph) {
if (instr->user_count() != 1) {
return std::nullopt;
}
HloInstruction* current_user = instr->users()[0];
std::optional<HloSharding> sharding;
std::vector<HloInstruction*> users_chain = {instr, current_user};
// Collect single user instructions along the way.
while (!current_user->has_sharding()) {
// Only consider single user chains.
if (current_user->users().size() != 1) {
users_chain.clear();
break;
}
current_user = current_user->users()[0];
users_chain.push_back(current_user);
}
// Early exit for unsupported cases.
if (users_chain.empty()) {
return std::nullopt;
}
for (int i = users_chain.size() - 1; i >= 1; --i) {
HloInstruction* user = users_chain[i];
HloInstruction* current = users_chain[i - 1];
CHECK(user->has_sharding());
sharding = ShardingPropagation::GetShardingFromUser(
*current, *user, INT64_MAX, is_spmd, call_graph,
/*sharding_helper=*/nullptr);
// We need to set the sharding to the instruction, because
// GetShardingFromUser() interface uses sharding from the instruction
// itself. It will be cleared out later.
if (sharding.has_value() && i != 1) {
current->set_sharding(*sharding);
continue;
}
break;
}
// Clear the sharding of the middle instructions we set the sharding of
// because they were unsharded.
for (int i = 1; i < users_chain.size() - 1; ++i) {
users_chain[i]->clear_sharding();
}
return sharding;
}

// Infer output sharding on index parallel dimensions for gather from operand
// and indices.
bool InferGatherParallelShardingFromOperands(
@@ -1071,9 +1025,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) {
} // namespace

bool InferDotShardingFromOperands(
HloInstruction* instruction, const CallGraph& call_graph,
HloInstruction* instruction,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
bool may_combine_partial_sharding, bool is_spmd) {
int64_t aggressiveness, bool may_combine_partial_sharding) {
auto from_operand = [&](int64_t operand_index) {
auto operand = instruction->operand(operand_index);
const HloSharding& operand_sharding = operand->sharding();
@@ -1128,55 +1082,66 @@ bool InferDotShardingFromOperands(
from_operand(1), instruction, may_combine_partial_sharding,
/*allow_aggressive_resharding=*/false);
}
// If not improved sharding found then do not set any sharding.

// Four cases based on if improved_operand_0 and improved_operand_1 are
// available.
// Case 0. Both operands have no improved sharding.
if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
return false;
}
// Sharding found from operand 0 but not operand 1. Set sharding from operand
// 0
// Case 1. Sharding found from operand 0 but not operand 1. Set sharding from
// operand 0.
if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
instruction->set_sharding(*improved_operand_0);
return true;
}
// Sharding found from operand 1 but not operand 0. Set sharding from operand
// 1
// Case 2. Sharding found from operand 1 but not operand 0. Set sharding from
// operand 1.
if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
instruction->set_sharding(*improved_operand_1);
return true;
}
// Case 3. Both operands have improved shardings.
CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
std::optional<HloSharding> lookahead_sharding =
LookaheadUserSharding(instruction, is_spmd, call_graph);

// If one of the improved shardings is a sub-tiling or equal to the other, use
// the better sharding with more tiles.
if (hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *improved_operand_0, *improved_operand_1)) {
instruction->set_sharding(*improved_operand_0);
return true;
}
if (hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *improved_operand_1, *improved_operand_0)) {
instruction->set_sharding(*improved_operand_1);
return true;
}

// If the two improved shardings are mergeable, there is no conflict.
if (std::optional<HloSharding> improved_sharding =
hlo_sharding_util::ReturnImprovedShardingImpl(
*improved_operand_0, &improved_operand_1.value(),
instruction->shape(), may_combine_partial_sharding,
/*allow_aggressive_resharding=*/false)) {
instruction->set_sharding(*improved_sharding);
return true;
}

if (aggressiveness < 3) {
// We can improve the dot with different shardings. Pause the propagation
// and wait for the winner between the two operands.
return false;
}

// The two improved sharding are different and we are at the highest
// aggressiveness. Prioritize the operand with larger size.
std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
*improved_operand_1};
bool priority_defined_with_lookahead = false;
// Found sharding from lookahead.
if (lookahead_sharding.has_value()) {
const bool operand_0_is_lookahead_subtiling =
hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *lookahead_sharding, *improved_operand_0);
const bool operand_1_is_lookahead_subtiling =
hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *lookahead_sharding, *improved_operand_1);
// If the sharding from operand 0 is a subtiling of the user, but not the
// one from operand 1 prioritize that sharding.
if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
priority_defined_with_lookahead = true;
}
// If the sharding from operand 1 is a subtiling of the user, but not the
// one from operand 0 prioritize that sharding.
if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
instruction->set_sharding(*improved_operand_1);
std::swap(sharding_priority[0], sharding_priority[1]);
priority_defined_with_lookahead = true;
}
}
// If lookahead didn't define a priority then use size.
if (!priority_defined_with_lookahead &&
ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
if (ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
std::swap(sharding_priority[0], sharding_priority[1]);
}

// Set primary sharding to the instruction and then try to improve it with
// the secondary sharding.
instruction->set_sharding(sharding_priority[0]);
@@ -1187,10 +1152,8 @@ bool InferDotShardingFromOperands(

// Convolution handling for InferShardingFromOperands().
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
const CallGraph& call_graph,
int64_t aggressiveness,
bool may_combine_partial_sharding,
bool is_spmd) {
bool may_combine_partial_sharding) {
auto get_partitions_for_dims =
[&](const HloInstruction* inst,
absl::Span<
@@ -1225,8 +1188,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
(lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
instruction->batch_group_count() == 1 &&
instruction->feature_group_count() == 1)) {
return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
may_combine_partial_sharding, is_spmd);
return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness,
may_combine_partial_sharding);
}
const auto& dnums = instruction->convolution_dimension_numbers();
const HloInstruction* lhs = instruction->operand(0);
@@ -2329,9 +2292,8 @@ bool ShardingPropagation::InferShardingFromOperands(
1);
}
case HloOpcode::kConvolution:
return InferConvolutionShardingFromOperands(
instruction, call_graph, aggressiveness, may_combine_partial_sharding,
is_spmd_);
return InferConvolutionShardingFromOperands(instruction, aggressiveness,
may_combine_partial_sharding);
case HloOpcode::kTranspose: {
const HloInstruction* input = instruction->operand(0);
if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) {
@@ -2420,9 +2382,8 @@ bool ShardingPropagation::InferShardingFromOperands(
case HloOpcode::kDot: {
const auto& dnums =
dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
return InferDotShardingFromOperands(instruction, call_graph, dnums,
may_combine_partial_sharding,
is_spmd_);
return InferDotShardingFromOperands(instruction, dnums, aggressiveness,
may_combine_partial_sharding);
}
case HloOpcode::kParameter: {
auto parent_it = computation_map.find(instruction->parent());
9 changes: 4 additions & 5 deletions xla/service/sharding_propagation.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_
#define XLA_SERVICE_SHARDING_PROPAGATION_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
@@ -35,17 +36,15 @@ namespace xla {
// Infers the shardings for a dot HLO op from the shardings on its operands,
// which are expected to have sharding annotations.
bool InferDotShardingFromOperands(
HloInstruction* instruction, const CallGraph& call_graph,
HloInstruction* instruction,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
bool may_combine_partial_sharding, bool is_spmd);
int64_t aggressiveness, bool may_combine_partial_sharding);

// Infers the shardings for a convolution HLO op from the shardings on its
// operands, which are expected to have sharding annotations.
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
const CallGraph& call_graph,
int64_t aggressiveness,
bool may_combine_partial_sharding,
bool is_spmd);
bool may_combine_partial_sharding);

// Remove Sharding custom-call instruction by folding the sharding attribute
// to its operand. If the operand already has a different sharding, insert a