Skip to content

Commit

Permalink
Fix a bug in OverestimateReshardingCost, which failed to take into account the fact that a replicated strategy is also a "tile maximal" strategy.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 681155227
  • Loading branch information
Google-ML-Automation committed Oct 2, 2024
1 parent 29cad9d commit 66e30bc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions xla/hlo/experimental/auto_sharding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/service/spmd:spmd_partitioner",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:btree_set",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
Expand Down
16 changes: 8 additions & 8 deletions xla/hlo/experimental/auto_sharding/cluster_environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
#include "xla/hlo/ir/hlo_sharding.h"
Expand Down Expand Up @@ -162,7 +163,7 @@ double ClusterEnvironment::ReshardingCostMixedMeshShape(
}
if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i],
(*src_tensor_dim_to_mesh_axis)[i])) {
// do nothing; the src is sharded more than the dest
// do nothing; the dst is sharded more than the src
continue;
}
if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i],
Expand Down Expand Up @@ -231,17 +232,16 @@ double ClusterEnvironment::CollectivePermuteCost(
// Overestimate the cost of replicating a tensor by decomposing the resharding
// operation as an all-gather on all mesh dimensions.
double ClusterEnvironment::OverestimateReplicationCost(
const Shape& shape, const HloSharding& src_spec,
const Shape& shape, const HloSharding& src_sharding,
const DeviceMesh& device_mesh) const {
if (src_spec.IsTileMaximal() || src_spec.IsManual()) {
// TODO(b/238210866) Do not use kInfinityCost.
return kInfinityCost;
if (src_sharding.IsReplicated()) {
return 0;
}
int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_spec);
int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_sharding);
double cost = 0.0;
for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) {
auto this_cost = this->AllGatherCost(bytes_moved, i);
cost += this_cost;
cost += src_sharding.IsTileMaximal() ? this->AllReduceCost(bytes_moved, i)
: this->AllGatherCost(bytes_moved, i);
bytes_moved *= device_mesh.dimensions()[i];
}
return cost;
Expand Down

0 comments on commit 66e30bc

Please sign in to comment.