diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 08c926c5268c0..0ebdd990e92cf 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -152,9 +152,6 @@ void GenerateScatterShardingFromOperands( if (it == scatter_shardings.end()) scatter_shardings.push_back(sharding); }; CHECK_EQ(scatter->scatter_operand_count(), 1); - const HloInstruction* scatter_data = scatter->scatter_operands()[0]; - const HloInstruction* scatter_indices = scatter->scatter_indices(); - const HloInstruction* scatter_update = scatter->scatter_updates()[0]; const HloSharding& indices_sharding = hlo_sharding_util:: ScatterIndexShardingFromUpdateIndexPassthroughDimensions(update_sharding, @@ -189,22 +186,21 @@ void GenerateScatterShardingFromOperands( const Shape& shape = scatter->shape(); scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( - data_sharding, scatter_data->shape(), shape, + data_sharding, shape, absl::MakeConstSpan(aligned_operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter indices sharding. scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( - indices_sharding, scatter_indices->shape(), shape, + indices_sharding, shape, absl::MakeConstSpan(scatter_parallel_dims->indices_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter update sharding. 
scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( - update_sharding, scatter_update->shape(), shape, - absl::MakeConstSpan(update_parallel_dims), + update_sharding, shape, absl::MakeConstSpan(update_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); for (const HloSharding& scatter_sharding : scatter_shardings) { @@ -448,7 +444,7 @@ BuildStrategyAndCost( if (hlo_sharding_util::IsSpatiallyPartitioned(data_spec)) { const HloSharding to_merge = hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - data_spec, data->shape(), gather_shape, + data_spec, gather_shape, absl::MakeConstSpan(aligned_operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims)); if (std::optional improved_spec = @@ -466,7 +462,7 @@ BuildStrategyAndCost( if (hlo_sharding_util::IsSpatiallyPartitioned(indices_spec)) { const HloSharding to_merge = hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - indices_spec, indices->shape(), gather_shape, + indices_spec, gather_shape, absl::MakeConstSpan( gather_parallel_dims->indices_parallel_dims), absl::MakeConstSpan(output_parallel_dims)); diff --git a/xla/hlo/utils/hlo_sharding_util.cc b/xla/hlo/utils/hlo_sharding_util.cc index e6248154036cb..f72ec2bbfc0c9 100644 --- a/xla/hlo/utils/hlo_sharding_util.cc +++ b/xla/hlo/utils/hlo_sharding_util.cc @@ -2499,8 +2499,7 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims( } HloSharding InferGatherScatterParallelShardingFromOperandSharding( - const HloSharding& operand_sharding, const Shape& operand_shape, - const Shape& shape, + const HloSharding& operand_sharding, const Shape& shape, absl::Span<const int64_t> output_aligned_operand_parallel_dims, absl::Span<const int64_t> output_parallel_dims) { return PropagateShardingAlongDimsAndReplicateOthers( diff --git a/xla/hlo/utils/hlo_sharding_util.h b/xla/hlo/utils/hlo_sharding_util.h index 67997dfaf8f20..3233fa1624549 100644 --- a/xla/hlo/utils/hlo_sharding_util.h +++ 
b/xla/hlo/utils/hlo_sharding_util.h @@ -358,8 +358,7 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims( // Infer output sharding on index parallel dimensions for gather/scatter from // gather operand/indices or scatter operands/indices/updates. HloSharding InferGatherScatterParallelShardingFromOperandSharding( - const HloSharding& operand_sharding, const Shape& operand_shape, - const Shape& shape, + const HloSharding& operand_sharding, const Shape& shape, absl::Span<const int64_t> output_aligned_operand_parallel_dims, absl::Span<const int64_t> output_parallel_dims); diff --git a/xla/service/sharding_propagation.cc b/xla/service/sharding_propagation.cc index 52e84f633a05d..b938bd5608bc2 100644 --- a/xla/service/sharding_propagation.cc +++ b/xla/service/sharding_propagation.cc @@ -480,8 +480,7 @@ bool InferGatherParallelShardingFromOperands( changed |= MaybeImproveInstructionSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - instruction->operand(0)->sharding(), - instruction->operand(0)->shape(), instruction->shape(), + instruction->operand(0)->sharding(), instruction->shape(), absl::MakeConstSpan(parallel_dims.operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims)), instruction, may_combine_partial_sharding); @@ -491,8 +490,7 @@ bool InferGatherParallelShardingFromOperands( changed |= MaybeImproveInstructionSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - instruction->operand(1)->sharding(), - instruction->operand(1)->shape(), instruction->shape(), + instruction->operand(1)->sharding(), instruction->shape(), absl::MakeConstSpan(parallel_dims.indices_parallel_dims), absl::MakeConstSpan(output_parallel_dims)), instruction, may_combine_partial_sharding); @@ -524,8 +522,7 @@ bool InferScatterParallelShardingFromOperands( changed |= MaybeImproveInstructionSubSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - scatter_operands[i]->sharding(), scatter_operands[i]->shape(), - 
shape, + scatter_operands[i]->sharding(), shape, absl::MakeConstSpan(parallel_dims.operand_parallel_dims), absl::MakeConstSpan(parallel_dims.operand_parallel_dims)), instruction, {i}, may_combine_partial_sharding); @@ -535,7 +532,7 @@ bool InferScatterParallelShardingFromOperands( if (hlo_sharding_util::IsSpatiallyPartitioned(scatter_indices)) { auto parallel_sharding_from_indices = hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - scatter_indices->sharding(), scatter_indices->shape(), shape, + scatter_indices->sharding(), shape, absl::MakeConstSpan(parallel_dims.indices_parallel_dims), absl::MakeConstSpan(parallel_dims.operand_parallel_dims)); for (int64_t i = 0; i != operand_count; ++i) { @@ -550,8 +547,8 @@ bool InferScatterParallelShardingFromOperands( changed |= MaybeImproveInstructionSubSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - scatter_updates[i]->sharding(), scatter_updates[i]->shape(), - shape, absl::MakeConstSpan(update_parallel_dims), + scatter_updates[i]->sharding(), shape, + absl::MakeConstSpan(update_parallel_dims), absl::MakeConstSpan(parallel_dims.operand_parallel_dims)), instruction, {i}, may_combine_partial_sharding); }