diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index ee63f4e96cc4f7..6a437e18a0353b 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -265,6 +265,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service/spmd:spmd_partitioner", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:btree_set", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/xla/hlo/experimental/auto_sharding/cluster_environment.cc index b5ad371562e0c8..9a68b636b79fa5 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/types/span.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -162,7 +163,7 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape( } if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i], (*src_tensor_dim_to_mesh_axis)[i])) { - // do nothing; the src is sharded more than the dest + // do nothing; the dst is sharded more than the src continue; } if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i], @@ -231,17 +232,16 @@ double ClusterEnvironment::CollectivePermuteCost( // Overestimate the cost of replicating a tensor by decomposing the resharding // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( - const Shape& shape, const HloSharding& src_spec, + const Shape& shape, const HloSharding& src_sharding, const DeviceMesh& device_mesh) const { - if (src_spec.IsTileMaximal() || src_spec.IsManual()) { - // TODO(b/238210866) Do not use kInfinityCost. - return kInfinityCost; + if (src_sharding.IsReplicated()) { + return 0; } - int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_spec); + int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_sharding); double cost = 0.0; for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) { - auto this_cost = this->AllGatherCost(bytes_moved, i); - cost += this_cost; + cost += src_sharding.IsTileMaximal() ? this->AllReduceCost(bytes_moved, i) + : this->AllGatherCost(bytes_moved, i); bytes_moved *= device_mesh.dimensions()[i]; } return cost;