Commit 592e214
This CL correctly computes the tensor dim to mesh axis mapping for mixed mesh strategies when computing resharding costs involving such a strategy.

PiperOrigin-RevId: 681513792
Google-ML-Automation committed Oct 2, 2024
1 parent 93be085 commit 592e214
Showing 6 changed files with 142 additions and 27 deletions.
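
For readers skimming the diff: the mapping this change computes associates each tensor dimension with the set of mesh axes it is sharded across, and under a mixed mesh strategy a single tensor dimension can span several mesh axes (an empty set means the dimension is replicated). The following is a minimal standalone sketch of that idea; it is not part of this commit and not the XLA implementation, and the function name TensorDimToMeshAxes and the explicit axis_order parameter are illustrative assumptions. For a rank-2 tensor tiled {4, 2} over a {2, 2, 2} mesh, tensor dim 0 maps to mesh axes {0, 1} and tensor dim 1 maps to mesh axis {2}.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Illustrative only (not the XLA implementation): walk the mesh axes in the
// given order and assign consecutive axes to each tensor dimension until the
// product of their sizes reaches that dimension's tile count. An empty set
// means the dimension is replicated.
std::vector<std::set<int64_t>> TensorDimToMeshAxes(
    const std::vector<int64_t>& tile_dims,     // tiles per tensor dimension
    const std::vector<int64_t>& mesh_shape,    // size of each mesh axis
    const std::vector<int64_t>& axis_order) {  // mesh axes in sharding order
  std::vector<std::set<int64_t>> mapping;
  std::size_t axis_idx = 0;
  for (int64_t tiles : tile_dims) {
    std::set<int64_t> axes;
    int64_t product = 1;
    while (product < tiles && axis_idx < axis_order.size()) {
      product *= mesh_shape[axis_order[axis_idx]];
      axes.insert(axis_order[axis_idx]);
      ++axis_idx;
    }
    mapping.push_back(axes);
  }
  return mapping;
}

int main() {
  // Rank-2 tensor tiled 4-ways on dim 0 and 2-ways on dim 1 over a 2x2x2 mesh.
  const auto mapping = TensorDimToMeshAxes({4, 2}, {2, 2, 2}, {0, 1, 2});
  // Prints: tensor dim 0 -> mesh axes { 0 1 } and tensor dim 1 -> mesh axes { 2 }.
  for (std::size_t i = 0; i < mapping.size(); ++i) {
    std::cout << "tensor dim " << i << " -> mesh axes {";
    for (int64_t axis : mapping[i]) std::cout << ' ' << axis;
    std::cout << " }\n";
  }
  return 0;
}

The new GetTensorDimToMeshDimMixedMeshSharding in auto_sharding_util.cc below performs the same accumulation, but derives the axis order from the sharding via GetMeshDimPermutationOrderInShardingSpec and returns an error if it runs out of mesh axes before covering a tile dimension.
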
4 changes: 3 additions & 1 deletion xla/hlo/experimental/auto_sharding/BUILD
@@ -261,10 +261,12 @@ cc_library(
":auto_sharding_strategy",
":auto_sharding_util",
":profiling_result",
"//xla:array",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/service/spmd:spmd_partitioner",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
8 changes: 7 additions & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding_test.cc
@@ -272,6 +272,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DAllOptions) {
option.device_mesh_ids = {0, 1, 2, 3};
option.device_mesh_alpha = {1.0, 1.0};
option.device_mesh_beta = {0.01, 1.0};
option.allow_mixed_mesh_shape = false;
RunMatMulAutoShardingWithOptions(option, 4, 2);

option.enable = true;
@@ -288,6 +289,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBeta) {
option.enable = true;
option.device_mesh_shape = {2, 2};
option.device_mesh_ids = {0, 1, 2, 3};
option.allow_mixed_mesh_shape = false;
RunMatMulAutoShardingWithOptions(option, 4, 2);

option.enable = true;
@@ -304,6 +306,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBetaMeshIds) {
AutoShardingOption option;
option.enable = true;
option.device_mesh_shape = {2, 2};
option.allow_mixed_mesh_shape = false;
RunMatMulAutoShardingWithOptions(option, 4, 2);

option.enable = true;
@@ -322,6 +325,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoMeshIds) {
option.device_mesh_shape = {2, 2};
option.device_mesh_alpha = {1.0, 1.0};
option.device_mesh_beta = {0.01, 1.0};
option.allow_mixed_mesh_shape = false;
RunMatMulAutoShardingWithOptions(option, 4, 2);

option.enable = true;
@@ -349,6 +353,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape3DAllOptions) {
TEST_F(AutoShardingTest, Matmul3DMeshShape2DSharding) {
AutoShardingOption option;
option.enable = true;
option.allow_mixed_mesh_shape = false;
option.device_mesh_shape = {1, 2, 2};
RunMatMulAutoShardingWithOptions(option, 4, 2);

@@ -458,7 +463,7 @@ TEST_F(AutoShardingTest, LargeSize) {
option.device_mesh_alpha = {1.0, 1.0, 1.0, 1.0};
option.device_mesh_beta = {1.0, 1.0, 1.0, 1.0};
option.memory_budget_per_device = (8192 + 8192 * 2 + 8192 * 4 / 8);
RunMatMulAutoShardingWithOptions(option, 7, 1);
RunMatMulAutoShardingWithOptions(option, 56, 1);
}

TEST_F(AutoShardingTest, InvalidOptions) {
@@ -716,6 +721,7 @@ ENTRY %elementwise {
.enable = true,
.preserve_shardings =
AutoShardingOption::PreserveShardingsType::kKeepAllShardings,
.allow_mixed_mesh_shape = false,
.only_allow_divisible_input_output = false,
.device_mesh_shape = {16, 16},
.device_mesh_alpha = {1.0, 1.0},
49 changes: 49 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
@@ -1193,6 +1193,55 @@ absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
return tensor_dim_to_device_dim;
}

absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
GetTensorDimToMeshDimMixedMeshSharding(int64_t tensor_shape_rank,
const HloSharding& sharding,
const DeviceMesh& device_mesh,
bool consider_reverse_device_meshes) {
CHECK(!sharding.IsReplicated());
// Check the compatibility of tensor_shape_rank and spec
if (tensor_shape_rank != sharding.TiledDataRank()) {
return absl::InvalidArgumentError(
"Tensor shape rank should be equal to the tiled data rank of the input "
"spec.");
}
if (!TileAssignmentMatchesMesh(sharding, device_mesh)) {
return absl::InvalidArgumentError(
"Device mesh and tile assignment need to have the same number of "
"sharded dims.");
}

TF_ASSIGN_OR_RETURN(
std::vector<int64_t> axes,
GetMeshDimPermutationOrderInShardingSpec(sharding, device_mesh,
consider_reverse_device_meshes));

std::vector<absl::btree_set<int64_t>> tensor_dim_to_mesh_axis_mapping;
int mesh_axis_idx = 0;
for (int i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) {
if (sharding.tile_assignment().dim(i) == 1) {
tensor_dim_to_mesh_axis_mapping.push_back({});
continue;
}

absl::btree_set<int64_t> mesh_axes_for_this_tensor_dim;
int product = 1;
do {
if (mesh_axis_idx >= device_mesh.num_dimensions()) {
return absl::InternalError(
"Mismatched mesh shapes encountered. This can happen when the "
"sharding does not map well to the mesh shape provided");
}
product *= device_mesh.dim(axes[mesh_axis_idx]);
mesh_axes_for_this_tensor_dim.insert(axes[mesh_axis_idx]);
mesh_axis_idx++;
} while (product < sharding.tile_assignment().dim(i));
CHECK(!mesh_axes_for_this_tensor_dim.empty());
tensor_dim_to_mesh_axis_mapping.push_back(mesh_axes_for_this_tensor_dim);
}
return tensor_dim_to_mesh_axis_mapping;
}

std::vector<int64_t> GetTensorDimToMeshDim(
int64_t tensor_shape_rank, const HloSharding& spec,
const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) {
10 changes: 10 additions & 0 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/functional/function_ref.h"
@@ -472,6 +473,15 @@ absl::StatusOr<int64_t> CheckArithmeticSequence(
// device mesh.
bool TileAssignmentMatchesMesh(const HloSharding& spec, const DeviceMesh& mesh);

absl::StatusOr<std::vector<int64_t>> GetMeshDimPermutationOrderInShardingSpec(
const HloSharding& spec, const Array<int64_t>& device_mesh,
bool consider_reverse_device_meshes);

absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
GetTensorDimToMeshDimMixedMeshSharding(
int64_t tensor_shape_rank, const HloSharding& sharding,
const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false);

// Get the mapped mesh dimension for every tensor dimension.
// The returned value maps ith tensor dim to one mesh dim. -1 means the tensor
// is replicated on that dimension.
92 changes: 70 additions & 22 deletions xla/hlo/experimental/auto_sharding/cluster_environment.cc
@@ -16,16 +16,21 @@ limitations under the License.
#include "xla/hlo/experimental/auto_sharding/cluster_environment.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h"
#include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/service/spmd/spmd_partitioner_util.h"
#include "xla/shape.h"

@@ -121,35 +126,79 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const {
return AllToAllCostUtil(num_bytes, mesh_dim, num_devices);
}

template <typename T>
bool IsSubset(absl::btree_set<T> superset, absl::btree_set<T> subset) {
for (const T& element : subset) {
if (!superset.contains(element)) {
return false;
}
}
return true;
}

// Do not consider device id changes yet.
double ClusterEnvironment::ReshardingCostMixedMeshShape(
const Shape& shape, absl::Span<const int64_t> src_tensor_dim_to_mesh_dim,
absl::Span<const int64_t> dst_tensor_dim_to_mesh_dim) const {
const Shape& shape, const HloSharding& src_sharding,
const HloSharding& dst_sharding) const {
absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
src_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding(
shape.rank(), src_sharding, device_mesh_,
/*consider_reverse_device_meshes=*/true);
absl::StatusOr<std::vector<absl::btree_set<int64_t>>>
dst_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding(
shape.rank(), dst_sharding, device_mesh_,
/*consider_reverse_device_meshes=*/true);
if (!src_tensor_dim_to_mesh_axis.ok() || !dst_tensor_dim_to_mesh_axis.ok()) {
return OverestimateReplicationCost(shape, src_sharding, device_mesh_);
}

int64_t num_devices = device_mesh_.num_elements();
double resharding_costs = 0.0;
std::vector<int64_t> collective_mesh_axes;
// Only consider sharded dimensions, do not consider replicate_on_last_dim.
for (size_t i = 0; i < shape.rank(); ++i) {
// Only consider sharded dimensions, do not consider replicate_on_last_dim.
if (src_tensor_dim_to_mesh_dim[i] == dst_tensor_dim_to_mesh_dim[i]) {
if ((*src_tensor_dim_to_mesh_axis)[i] ==
(*dst_tensor_dim_to_mesh_axis)[i]) {
continue;
}
if (dst_tensor_dim_to_mesh_dim[i] == -1 ||
src_tensor_dim_to_mesh_dim[i] == -1) {
// AllToAll cost
int64_t communication_dim;
if (dst_tensor_dim_to_mesh_dim[i] != -1) {
communication_dim = dst_tensor_dim_to_mesh_dim[i];
} else {
communication_dim = src_tensor_dim_to_mesh_dim[i];
}
int64_t communication_bytes = ByteSizeOfShape(shape);
resharding_costs +=
AllToAllCostUtil(communication_bytes, communication_dim, num_devices);
} else {
if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i],
(*src_tensor_dim_to_mesh_axis)[i])) {
// do nothing; the src is sharded more than the dest
continue;
}
if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i],
(*dst_tensor_dim_to_mesh_axis)[i])) {
// Do not support this sharding, assuming it is gonna be very expensive.
return kInfinityCost;
return OverestimateReplicationCost(shape, src_sharding, device_mesh_);
}
for (int64_t mesh_dim : (*src_tensor_dim_to_mesh_axis)[i]) {
if (!(*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) {
collective_mesh_axes.push_back(mesh_dim);
}
}
}
return resharding_costs;

auto is_mesh_axis_used_for_dst_sharding = [&](int64_t mesh_dim) {
int end = dst_sharding.ReplicateOnLastTileDim()
? dst_tensor_dim_to_mesh_axis->size() - 1
: dst_tensor_dim_to_mesh_axis->size();
for (int i = 0; i < end; ++i) {
if ((*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) {
return true;
}
}
return false;
};

double resharding_cost = 0.0;
int64_t communication_bytes = ByteSizeOfShape(shape);
for (int mesh_dim : collective_mesh_axes) {
bool used_for_dst_sharding = is_mesh_axis_used_for_dst_sharding(mesh_dim);
resharding_cost +=
used_for_dst_sharding
? AllToAllCostUtil(communication_bytes, mesh_dim, num_devices)
: AllGatherCost(communication_bytes, mesh_dim);
}
return resharding_cost;
}

double ClusterEnvironment::CollectivePermuteCost(
@@ -313,8 +362,7 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape,
dst_tensor_dim_to_mesh_dim_or.value();

if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) {
return ReshardingCostMixedMeshShape(shape, src_tensor_dim_to_mesh_dim,
dst_tensor_dim_to_mesh_dim);
return ReshardingCostMixedMeshShape(shape, src_spec, dst_spec);
}

AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim);
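
To summarize the per-dimension decision in the rewritten ReshardingCostMixedMeshShape above: if the destination's mesh-axis set for a tensor dimension equals or contains the source's, nothing is charged; if neither set contains the other, the function falls back to OverestimateReplicationCost; otherwise every axis the source uses on that dimension but the destination does not is charged as an all-to-all when the destination still uses that axis somewhere else, and as an all-gather when it does not. The sketch below is a simplified standalone rendering of that logic, not the XLA code: the cost functions are arbitrary placeholders, the names IsSubsetOf and MixedMeshReshardingCost are illustrative, and the ReplicateOnLastTileDim handling is omitted.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

using AxisSet = std::set<int64_t>;

// True if every element of `subset` is also in `superset`.
bool IsSubsetOf(const AxisSet& superset, const AxisSet& subset) {
  for (int64_t element : subset) {
    if (superset.count(element) == 0) return false;
  }
  return true;
}

// Arbitrary placeholder costs standing in for the cluster environment's real
// communication-cost models.
double AllToAllCost(double bytes) { return 1.0 * bytes; }
double AllGatherCost(double bytes) { return 2.0 * bytes; }
double ConservativeReplicationCost(double bytes) { return 10.0 * bytes; }

// src[i] / dst[i]: the mesh axes that tensor dimension i is sharded over
// before and after resharding.
double MixedMeshReshardingCost(const std::vector<AxisSet>& src,
                               const std::vector<AxisSet>& dst, double bytes) {
  std::vector<int64_t> collective_axes;
  for (std::size_t i = 0; i < src.size(); ++i) {
    if (src[i] == dst[i]) continue;
    // The destination shards this dimension over a superset of the source's
    // axes: no collective is charged for it.
    if (IsSubsetOf(dst[i], src[i])) continue;
    // Neither set contains the other: treat as unsupported, be pessimistic.
    if (!IsSubsetOf(src[i], dst[i])) return ConservativeReplicationCost(bytes);
    // Collect axes the source uses here but the destination does not.
    for (int64_t axis : src[i]) {
      if (dst[i].count(axis) == 0) collective_axes.push_back(axis);
    }
  }
  // An axis still used somewhere by the destination implies an all-to-all;
  // an axis the destination no longer uses anywhere implies an all-gather.
  auto used_by_dst = [&dst](int64_t axis) {
    for (const AxisSet& axes : dst) {
      if (axes.count(axis) > 0) return true;
    }
    return false;
  };
  double cost = 0.0;
  for (int64_t axis : collective_axes) {
    cost += used_by_dst(axis) ? AllToAllCost(bytes) : AllGatherCost(bytes);
  }
  return cost;
}

int main() {
  // Rank-2 tensor on a 2x2x2 mesh: dim 0 goes from mesh axes {0, 1} to {0};
  // axis 1 is unused by the destination, so the move is costed as an
  // all-gather over axis 1.
  std::cout << MixedMeshReshardingCost({{0, 1}, {2}}, {{0}, {2}},
                                       /*bytes=*/64.0)
            << "\n";  // 128
  return 0;
}

In the commit itself, the per-dimension axis sets come from GetTensorDimToMeshDimMixedMeshSharding, and the function also falls back to OverestimateReplicationCost when either mapping cannot be computed.
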
6 changes: 3 additions & 3 deletions xla/hlo/experimental/auto_sharding/cluster_environment.h
@@ -145,9 +145,9 @@ class ClusterEnvironment {

double AllToAllCost(double num_bytes, int mesh_dim) const;

double ReshardingCostMixedMeshShape(
const Shape& shape, absl::Span<const int64_t> src_tensor_dim_to_mesh_dim,
absl::Span<const int64_t> dst_tensor_dim_to_mesh_dim) const;
double ReshardingCostMixedMeshShape(const Shape& shape,
const HloSharding& src_sharding,
const HloSharding& dst_sharding) const;

double CollectivePermuteCost(
double num_bytes,
