1. Remove some unused code that employs extremely simple heuristics to shard modules. Specifically, this removes AnnotateShardingWithSimpleHeuristic and its associated option. This code seems to have been used for the purposes of evaluating ALPA for the OSDI paper.

2. Also remove an unused parameter from a function.

PiperOrigin-RevId: 676190876
Google-ML-Automation committed Sep 19, 2024
1 parent 9f30510 commit 6f06258
Showing 6 changed files with 3 additions and 138 deletions.
124 changes: 1 addition & 123 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -3229,122 +3229,6 @@ absl::Status GenerateReduceScatter(
return absl::OkStatus();
}

void AnnotateShardingWithSimpleHeuristic(
    HloModule* module, const std::string& heuristic, const AliasMap& alias_map,
    const ClusterEnvironment& cluster_env) {
  const DeviceMesh& device_mesh = cluster_env.device_mesh_;
  const DeviceMesh& device_mesh_1d = cluster_env.device_mesh_1d_;
  int64_t num_devices = device_mesh.num_elements();

  // Count the non-one mesh dimension.
  size_t mesh_nn_dims = 0;
  for (int dim : device_mesh.dimensions()) {
    if (dim > 1) {
      mesh_nn_dims++;
    }
  }

  // Shard instructions
  HloComputation* entry_computation = module->entry_computation();
  for (HloInstruction* inst : entry_computation->instructions()) {
    if (inst->opcode() == HloOpcode::kParameter) {
      HloSharding output_spec = HloSharding::Replicate();
      inst->set_sharding(output_spec);

      if (heuristic == "shard-largest") {
        std::vector<int64_t> lengths;
        lengths.reserve(inst->shape().rank());
        for (int64_t i = 0; i < inst->shape().rank(); ++i) {
          lengths.push_back(inst->shape().dimensions(i));
        }

        std::vector<int> indices = Argsort(lengths);
        int common_dims = std::min(mesh_nn_dims, indices.size());

        if (common_dims < 1) {
          continue;
        }

        if (common_dims == 1) {
          int dim = indices[0];
          int length = lengths[dim];
          if (length % num_devices == 0) {
            output_spec = Tile(inst->shape(), {dim}, {0}, device_mesh_1d);
          }
        } else {
          int dim1 = indices[0];
          int length1 = lengths[dim1];
          int dim0 = indices[1];
          int length0 = lengths[dim0];

          if (length0 % device_mesh.dim(0) == 0 &&
              length1 % device_mesh.dim(1) == 0) {
            output_spec =
                Tile(inst->shape(), {dim0, dim1}, {0, 1}, device_mesh);
          }
        }
      } else if (heuristic == "shard-first") {
        if (inst->shape().rank() > 0 &&
            inst->shape().dimensions(0) % num_devices == 0) {
          output_spec = Tile(inst->shape(), {0}, {0}, device_mesh_1d);
        }
      } else if (heuristic == "shard-last") {
        int64_t last_dim = inst->shape().rank() - 1;
        if (inst->shape().rank() > 0 &&
            inst->shape().dimensions(last_dim) % num_devices == 0) {
          output_spec = Tile(inst->shape(), {last_dim}, {0}, device_mesh_1d);
        }
      } else {
        LOG(FATAL) << "Invalid heuristic: " << heuristic;
      }

      inst->set_sharding(output_spec);
    } else if (inst->opcode() == HloOpcode::kDot) {
      const HloInstruction* lhs = inst->operand(0);
      const HloInstruction* rhs = inst->operand(1);
      const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers();
      // const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions();
      // const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions();
      tsl::protobuf::RepeatedField<int64_t> lhs_space_dims, rhs_space_dims;
      std::tie(lhs_space_dims, rhs_space_dims) =
          GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums);
    }
  }

  // Meet the alias requirement for the output tuple.
  HloInstruction* output = entry_computation->root_instruction();
  const Shape& out_shape = output->shape();
  ShapeTree<HloSharding> tuple_sharding(out_shape, HloSharding::Replicate());
  std::vector<HloSharding> flattened_shardings;

  std::function<void(HloInstruction*)> get_flattened_shardings;
  get_flattened_shardings = [&](HloInstruction* cur) {
    for (int64_t i = 0; i < cur->operand_count(); ++i) {
      HloInstruction* operand = cur->mutable_operand(i);

      if (operand->shape().IsTuple()) {
        get_flattened_shardings(operand);
      } else {
        if (alias_map.contains(operand)) {
          operand = alias_map.at(operand);
        }
        if (!operand->has_sharding()) {
          operand->set_sharding(HloSharding::Replicate());
        }
        CHECK(operand->has_sharding());
        flattened_shardings.push_back(operand->sharding());
      }
    }
  };
  get_flattened_shardings(output);
  int i = 0;
  for (auto& leaf : tuple_sharding.leaves()) {
    leaf.second = flattened_shardings[i++];
  }
  CHECK_EQ(i, flattened_shardings.size());
  output->set_sharding(HloSharding::Tuple(tuple_sharding));
}

// Filter strategies according to the option.force_batch_dim_to_mesh_dim.
// This can be used to forcibly generate data-parallel strategies.
absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape,
@@ -3978,12 +3862,6 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(
<< option_.memory_budget_per_device;
}

  if (!option_.force_simple_heuristic.empty()) {
    AnnotateShardingWithSimpleHeuristic(
        module, option_.force_simple_heuristic, alias_map, cluster_env);
    return AutoShardingResult::kModuleChangedShardingPerformed;
  }

if (option_.force_batch_dim_to_mesh_dim >= 0) {
spmd::DisableIncompatibleMixedMeshShapeAndForceBatchDim(
batch_dim_map, sequence.instructions(), device_mesh.num_elements(),
@@ -3992,7 +3870,7 @@ absl::StatusOr<AutoShardingResult> AutoShardingImplementation::RunAutoSharding(

// ----- Analyze depth -----
spmd::InstructionDepthMap ins_depth_map;
-  ins_depth_map = spmd::BuildInstructionDepthMap(sequence, batch_dim_map);
+  ins_depth_map = spmd::BuildInstructionDepthMap(sequence);

// ----- Build strategies and costs -----
spmd::StrategyMap strategy_map;
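For context, the "shard-largest" branch of the deleted pass reduces to: argsort the tensor dimensions by length, then tile the largest dimension across a 1D mesh (or the two largest across a 2D mesh) when the lengths divide the mesh evenly, and otherwise leave the parameter replicated. Below is a minimal standalone sketch of that logic using plain vectors rather than the XLA types (Shape, DeviceMesh, HloSharding) above; ArgsortDescending and PickShardLargestDims are hypothetical names, and the descending order of Argsort is an assumption.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the Argsort used in the deleted code; assumes a
// descending sort, so v[result[0]] is the largest dimension length.
std::vector<int> ArgsortDescending(const std::vector<int64_t>& v) {
  std::vector<int> idx(v.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(), [&](int a, int b) { return v[a] > v[b]; });
  return idx;
}

// Sketch of "shard-largest": returns the tensor dims to tile (paired with
// mesh dims 0..k-1), or empty if the shape should stay replicated.
std::vector<int> PickShardLargestDims(const std::vector<int64_t>& lengths,
                                      const std::vector<int64_t>& mesh_dims) {
  // Count the non-one mesh dimensions, as the deleted code does.
  size_t mesh_nn_dims = 0;
  for (int64_t d : mesh_dims) {
    if (d > 1) mesh_nn_dims++;
  }

  std::vector<int> indices = ArgsortDescending(lengths);
  size_t common_dims = std::min(mesh_nn_dims, indices.size());
  if (common_dims < 1) return {};

  int64_t num_devices = 1;
  for (int64_t d : mesh_dims) num_devices *= d;

  if (common_dims == 1) {
    // 1D mesh: tile the single largest dimension across all devices.
    int dim = indices[0];
    return lengths[dim] % num_devices == 0 ? std::vector<int>{dim}
                                           : std::vector<int>{};
  }
  // 2D mesh: tile the two largest dimensions if both divide evenly.
  int dim1 = indices[0], dim0 = indices[1];
  if (lengths[dim0] % mesh_dims[0] == 0 && lengths[dim1] % mesh_dims[1] == 0) {
    return {dim0, dim1};
  }
  return {};
}

int main() {
  // A [256, 1024] parameter on a 2x4 mesh divides evenly on both dims,
  // so both get tiled: prints "0 1".
  for (int d : PickShardLargestDims({256, 1024}, {2, 4})) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
}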
5 changes: 0 additions & 5 deletions xla/hlo/experimental/auto_sharding/auto_sharding.h
@@ -177,11 +177,6 @@ absl::Status HandleConv(std::unique_ptr<StrategyGroup>& strategy_group,
const AutoShardingOption& option,
const CallGraph& call_graph);

void AnnotateShardingWithSimpleHeuristic(HloModule* module,
                                         const std::string& heuristic,
                                         const AliasMap& alias_map,
                                         const ClusterEnvironment& cluster_env);
// Handle alias: alias pairs must have the same HloSharding.
// To deal with aliases, we do special processing both before and after
// BuildStrategyAndCost. Because it is easier to handle elementwise
2 changes: 0 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_option.cc
@@ -84,8 +84,6 @@ std::string AutoShardingOption::ToString() const {
absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape));
lines.push_back(absl::StrCat("solve_nd_sharding_iteratively: ",
solve_nd_sharding_iteratively));
  lines.push_back(
      absl::StrCat("force_simple_heuristic: ", force_simple_heuristic));
lines.push_back(absl::StrCat("force_strategy: ", force_strategy));

if (force_strategy) {
4 changes: 0 additions & 4 deletions xla/hlo/experimental/auto_sharding/auto_sharding_option.h
@@ -125,10 +125,6 @@ struct AutoShardingOption {
// strategies for N-D mesh shape.
bool solve_nd_sharding_iteratively = true;

  // If it is not empty, forcibly use simple heuristic strategies
  // instead of the ILP solver. This is used for ablation study.
  std::string force_simple_heuristic;

// If true, forcibly set the strategy of some instructions.
bool force_strategy = false;
std::vector<int64_t> force_strategy_inst_indices;
3 changes: 1 addition & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -157,8 +157,7 @@ std::optional<HloSharding> PropagateReduceWindowSharding(
// We also assign a much larger distance to heavy operators (e.g., dot,
// convolution).
InstructionDepthMap BuildInstructionDepthMap(
-    const HloInstructionSequence& sequence,
-    const InstructionBatchDimMap& batch_dim_map) {
+    const HloInstructionSequence& sequence) {
const std::vector<HloInstruction*>& instructions = sequence.instructions();

InstructionDepthMap depth_map;
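The body of BuildInstructionDepthMap is elided by this diff, but the comment above it describes the idea: one pass over a topologically ordered instruction sequence, taking each instruction's depth as the maximum operand depth plus an edge weight, with heavy operators (dot, convolution) weighted far more than cheap elementwise ops. A rough sketch under those assumptions; Node and kHeavyWeight are hypothetical stand-ins, not the XLA types the real pass operates on.

#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an instruction in a scheduled HLO sequence.
struct Node {
  std::string opcode;  // e.g. "parameter", "add", "dot"
  std::vector<const Node*> operands;
};

// Depth = longest weighted path from any root, computed in one pass over a
// topologically ordered sequence. Heavy ops (dot, convolution) count as a
// much longer edge than cheap elementwise ops; kHeavyWeight is an assumed
// constant for illustration only.
std::unordered_map<const Node*, int64_t> BuildDepthMap(
    const std::vector<const Node*>& sequence) {
  constexpr int64_t kHeavyWeight = 1000;
  std::unordered_map<const Node*, int64_t> depth;
  for (const Node* node : sequence) {
    int64_t max_operand_depth = 0;
    for (const Node* operand : node->operands) {
      max_operand_depth = std::max(max_operand_depth, depth.at(operand));
    }
    const bool heavy = node->opcode == "dot" || node->opcode == "convolution";
    depth[node] = node->operands.empty()
                      ? 0
                      : max_operand_depth + (heavy ? kHeavyWeight : 1);
  }
  return depth;
}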
3 changes: 1 addition & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.h
@@ -300,8 +300,7 @@ std::optional<HloSharding> GetInputSharding(const HloInstruction* ins,
// instruction. We also assign a much larger distance to heavy operators (e.g.,
// dot, convolution).
InstructionDepthMap BuildInstructionDepthMap(
-    const HloInstructionSequence& sequence,
-    const InstructionBatchDimMap& batch_dim_map);
+    const HloInstructionSequence& sequence);

std::string GetBatchDimMapKey(const HloInstruction* ins, int64_t idx = -1);

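For downstream callers the signature change is mechanical: drop the now-unused second argument and pass only the scheduled sequence, as in the RunAutoSharding hunk above.

// The batch-dim map argument is gone; only the schedule is consumed.
spmd::InstructionDepthMap depth_map =
    spmd::BuildInstructionDepthMap(sequence);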
