Skip to content

Commit

Permalink
Clean-up. Remove unused argument from `InferGatherScatterParallelShar…
Browse files Browse the repository at this point in the history
…dingFromOperandSharding`.

PiperOrigin-RevId: 681578058
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Oct 2, 2024
1 parent 6fd8234 commit 1a1bfe0
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 22 deletions.
14 changes: 5 additions & 9 deletions xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,6 @@ void GenerateScatterShardingFromOperands(
if (it == scatter_shardings.end()) scatter_shardings.push_back(sharding);
};
CHECK_EQ(scatter->scatter_operand_count(), 1);
const HloInstruction* scatter_data = scatter->scatter_operands()[0];
const HloInstruction* scatter_indices = scatter->scatter_indices();
const HloInstruction* scatter_update = scatter->scatter_updates()[0];

const HloSharding& indices_sharding = hlo_sharding_util::
ScatterIndexShardingFromUpdateIndexPassthroughDimensions(update_sharding,
Expand Down Expand Up @@ -189,22 +186,21 @@ void GenerateScatterShardingFromOperands(
const Shape& shape = scatter->shape();
scatter_shardings_insert(
hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding(
data_sharding, scatter_data->shape(), shape,
data_sharding, shape,
absl::MakeConstSpan(aligned_operand_parallel_dims),
absl::MakeConstSpan(output_parallel_dims)));

// Infer output sharding from scatter indices sharding.
scatter_shardings_insert(
hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding(
indices_sharding, scatter_indices->shape(), shape,
indices_sharding, shape,
absl::MakeConstSpan(scatter_parallel_dims->indices_parallel_dims),
absl::MakeConstSpan(output_parallel_dims)));

// Infer output sharding from scatter update sharding.
scatter_shardings_insert(
hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding(
update_sharding, scatter_update->shape(), shape,
absl::MakeConstSpan(update_parallel_dims),
update_sharding, shape, absl::MakeConstSpan(update_parallel_dims),
absl::MakeConstSpan(output_parallel_dims)));

for (const HloSharding& scatter_sharding : scatter_shardings) {
Expand Down Expand Up @@ -448,7 +444,7 @@ BuildStrategyAndCost(
if (hlo_sharding_util::IsSpatiallyPartitioned(data_spec)) {
const HloSharding to_merge = hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
data_spec, data->shape(), gather_shape,
data_spec, gather_shape,
absl::MakeConstSpan(aligned_operand_parallel_dims),
absl::MakeConstSpan(output_parallel_dims));
if (std::optional<HloSharding> improved_spec =
Expand All @@ -466,7 +462,7 @@ BuildStrategyAndCost(
if (hlo_sharding_util::IsSpatiallyPartitioned(indices_spec)) {
const HloSharding to_merge = hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
indices_spec, indices->shape(), gather_shape,
indices_spec, gather_shape,
absl::MakeConstSpan(
gather_parallel_dims->indices_parallel_dims),
absl::MakeConstSpan(output_parallel_dims));
Expand Down
3 changes: 1 addition & 2 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2499,8 +2499,7 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims(
}

HloSharding InferGatherScatterParallelShardingFromOperandSharding(
const HloSharding& operand_sharding, const Shape& operand_shape,
const Shape& shape,
const HloSharding& operand_sharding, const Shape& shape,
absl::Span<const int64_t> output_aligned_operand_parallel_dims,
absl::Span<const int64_t> output_parallel_dims) {
return PropagateShardingAlongDimsAndReplicateOthers(
Expand Down
3 changes: 1 addition & 2 deletions xla/hlo/utils/hlo_sharding_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,7 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims(
// Infer output sharding on index parallel dimensions for gather/scatter from
// gather operand/indices or scatter operands/indices/updates.
HloSharding InferGatherScatterParallelShardingFromOperandSharding(
const HloSharding& operand_sharding, const Shape& operand_shape,
const Shape& shape,
const HloSharding& operand_sharding, const Shape& shape,
absl::Span<const int64_t> output_aligned_operand_parallel_dims,
absl::Span<const int64_t> output_parallel_dims);

Expand Down
15 changes: 6 additions & 9 deletions xla/service/sharding_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,7 @@ bool InferGatherParallelShardingFromOperands(
changed |= MaybeImproveInstructionSharding(
hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
instruction->operand(0)->sharding(),
instruction->operand(0)->shape(), instruction->shape(),
instruction->operand(0)->sharding(), instruction->shape(),
absl::MakeConstSpan(parallel_dims.operand_parallel_dims),
absl::MakeConstSpan(output_parallel_dims)),
instruction, may_combine_partial_sharding);
Expand All @@ -491,8 +490,7 @@ bool InferGatherParallelShardingFromOperands(
changed |= MaybeImproveInstructionSharding(
hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
instruction->operand(1)->sharding(),
instruction->operand(1)->shape(), instruction->shape(),
instruction->operand(1)->sharding(), instruction->shape(),
absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
absl::MakeConstSpan(output_parallel_dims)),
instruction, may_combine_partial_sharding);
Expand Down Expand Up @@ -524,8 +522,7 @@ bool InferScatterParallelShardingFromOperands(
changed |= MaybeImproveInstructionSubSharding(
hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
scatter_operands[i]->sharding(), scatter_operands[i]->shape(),
shape,
scatter_operands[i]->sharding(), shape,
absl::MakeConstSpan(parallel_dims.operand_parallel_dims),
absl::MakeConstSpan(parallel_dims.operand_parallel_dims)),
instruction, {i}, may_combine_partial_sharding);
Expand All @@ -535,7 +532,7 @@ bool InferScatterParallelShardingFromOperands(
if (hlo_sharding_util::IsSpatiallyPartitioned(scatter_indices)) {
auto parallel_sharding_from_indices = hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
scatter_indices->sharding(), scatter_indices->shape(), shape,
scatter_indices->sharding(), shape,
absl::MakeConstSpan(parallel_dims.indices_parallel_dims),
absl::MakeConstSpan(parallel_dims.operand_parallel_dims));
for (int64_t i = 0; i != operand_count; ++i) {
Expand All @@ -550,8 +547,8 @@ bool InferScatterParallelShardingFromOperands(
changed |= MaybeImproveInstructionSubSharding(
hlo_sharding_util::
InferGatherScatterParallelShardingFromOperandSharding(
scatter_updates[i]->sharding(), scatter_updates[i]->shape(),
shape, absl::MakeConstSpan(update_parallel_dims),
scatter_updates[i]->sharding(), shape,
absl::MakeConstSpan(update_parallel_dims),
absl::MakeConstSpan(parallel_dims.operand_parallel_dims)),
instruction, {i}, may_combine_partial_sharding);
}
Expand Down

0 comments on commit 1a1bfe0

Please sign in to comment.