diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/xla/hlo/experimental/auto_sharding/cluster_environment.cc index b5ad371562e0c..9a68b636b79fa 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/types/span.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -162,7 +163,7 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape( } if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i], (*src_tensor_dim_to_mesh_axis)[i])) { - // do nothing; the src is sharded more than the dest + // do nothing; the dst is sharded more than the src continue; } if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i], @@ -231,17 +232,16 @@ double ClusterEnvironment::CollectivePermuteCost( // Overestimate the cost of replicating a tensor by decomposing the resharding // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( - const Shape& shape, const HloSharding& src_spec, + const Shape& shape, const HloSharding& src_sharding, const DeviceMesh& device_mesh) const { - if (src_spec.IsTileMaximal() || src_spec.IsManual()) { - // TODO(b/238210866) Do not use kInfinityCost. - return kInfinityCost; + if (src_sharding.IsReplicated()) { + return 0; } - int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_spec); + int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_sharding); double cost = 0.0; for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) { - auto this_cost = this->AllGatherCost(bytes_moved, i); - cost += this_cost; + cost += src_sharding.IsTileMaximal() ? this->AllReduceCost(bytes_moved, i) + : this->AllGatherCost(bytes_moved, i); bytes_moved *= device_mesh.dimensions()[i]; } return cost;