diff --git a/xla/hlo/experimental/auto_sharding/BUILD b/xla/hlo/experimental/auto_sharding/BUILD index ade737608da33..ee63f4e96cc4f 100644 --- a/xla/hlo/experimental/auto_sharding/BUILD +++ b/xla/hlo/experimental/auto_sharding/BUILD @@ -261,10 +261,12 @@ cc_library( ":auto_sharding_strategy", ":auto_sharding_util", ":profiling_result", - "//xla:array", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 55d53e34619a8..cd12aeaf3d846 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -272,6 +272,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DAllOptions) { option.device_mesh_ids = {0, 1, 2, 3}; option.device_mesh_alpha = {1.0, 1.0}; option.device_mesh_beta = {0.01, 1.0}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -288,6 +289,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBeta) { option.enable = true; option.device_mesh_shape = {2, 2}; option.device_mesh_ids = {0, 1, 2, 3}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -304,6 +306,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBetaMeshIds) { AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -322,6 +325,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoMeshIds) { option.device_mesh_shape = {2, 2}; option.device_mesh_alpha = {1.0, 1.0}; option.device_mesh_beta = {0.01, 1.0}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -349,6 +353,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape3DAllOptions) { TEST_F(AutoShardingTest, Matmul3DMeshShape2DSharding) { AutoShardingOption option; option.enable = true; + option.allow_mixed_mesh_shape = false; option.device_mesh_shape = {1, 2, 2}; RunMatMulAutoShardingWithOptions(option, 4, 2); @@ -458,7 +463,7 @@ TEST_F(AutoShardingTest, LargeSize) { option.device_mesh_alpha = {1.0, 1.0, 1.0, 1.0}; option.device_mesh_beta = {1.0, 1.0, 1.0, 1.0}; option.memory_budget_per_device = (8192 + 8192 * 2 + 8192 * 4 / 8); - RunMatMulAutoShardingWithOptions(option, 7, 1); + RunMatMulAutoShardingWithOptions(option, 56, 1); } TEST_F(AutoShardingTest, InvalidOptions) { @@ -716,6 +721,7 @@ ENTRY %elementwise { .enable = true, .preserve_shardings = AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .allow_mixed_mesh_shape = false, .only_allow_divisible_input_output = false, .device_mesh_shape = {16, 16}, .device_mesh_alpha = {1.0, 1.0}, diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 5ee4d464ff116..641f02ef2e918 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -1193,6 +1193,55 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( return tensor_dim_to_device_dim; } +absl::StatusOr>> +GetTensorDimToMeshDimMixedMeshSharding(int64_t tensor_shape_rank, + const HloSharding& sharding, + const DeviceMesh& device_mesh, + bool consider_reverse_device_meshes) { + CHECK(!sharding.IsReplicated()); + // Check the compatibility of tensor_shape_rank and spec + if (tensor_shape_rank != sharding.TiledDataRank()) { + return absl::InvalidArgumentError( + "Tensor shape rank should be equal to the tiled data rank of the input " + "spec."); + } + if (!TileAssignmentMatchesMesh(sharding, device_mesh)) { + return absl::InvalidArgumentError( + "Device mesh and tile assignment need to have the same number of " + "sharded dims."); + } + + TF_ASSIGN_OR_RETURN( + std::vector axes, + GetMeshDimPermutationOrderInShardingSpec(sharding, device_mesh, + consider_reverse_device_meshes)); + + std::vector> tensor_dim_to_mesh_axis_mapping; + int mesh_axis_idx = 0; + for (int i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) == 1) { + tensor_dim_to_mesh_axis_mapping.push_back({}); + continue; + } + + absl::btree_set mesh_axes_for_this_tensor_dim; + int product = 1; + do { + if (mesh_axis_idx >= device_mesh.num_dimensions()) { + return absl::InternalError( + "Mismatched mesh shapes encountered. This can happen when the " + "sharding does not map well to the mesh shape provided"); + } + product *= device_mesh.dim(axes[mesh_axis_idx]); + mesh_axes_for_this_tensor_dim.insert(axes[mesh_axis_idx]); + mesh_axis_idx++; + } while (product < sharding.tile_assignment().dim(i)); + CHECK(!mesh_axes_for_this_tensor_dim.empty()); + tensor_dim_to_mesh_axis_mapping.push_back(mesh_axes_for_this_tensor_dim); + } + return tensor_dim_to_mesh_axis_mapping; +} + std::vector GetTensorDimToMeshDim( int64_t tensor_shape_rank, const HloSharding& spec, const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) { diff --git a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 3b8dd44bd094e..0f758e8be7f5c 100644 --- a/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" @@ -472,6 +473,15 @@ absl::StatusOr CheckArithmeticSequence( // device mesh. bool TileAssignmentMatchesMesh(const HloSharding& spec, const DeviceMesh& mesh); +absl::StatusOr> GetMeshDimPermutationOrderInShardingSpec( + const HloSharding& spec, const Array& device_mesh, + bool consider_reverse_device_meshes); + +absl::StatusOr>> +GetTensorDimToMeshDimMixedMeshSharding( + int64_t tensor_shape_rank, const HloSharding& sharding, + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false); + // Get the mapped mesh dimension for every tensor dimension. // The returned value maps ith tensor dim to one mesh dim. -1 means the tensor // is replicated on that dimension. diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/xla/hlo/experimental/auto_sharding/cluster_environment.cc index f9c1ce429a114..b5ad371562e0c 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include +#include #include #include #include @@ -23,9 +24,13 @@ limitations under the License. #include #include +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/spmd/spmd_partitioner_util.h" #include "xla/shape.h" @@ -121,35 +126,79 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { return AllToAllCostUtil(num_bytes, mesh_dim, num_devices); } +template +bool IsSubset(absl::btree_set superset, absl::btree_set subset) { + for (const T& element : subset) { + if (!superset.contains(element)) { + return false; + } + } + return true; +} + // Do not consider device id changes yet. double ClusterEnvironment::ReshardingCostMixedMeshShape( - const Shape& shape, absl::Span src_tensor_dim_to_mesh_dim, - absl::Span dst_tensor_dim_to_mesh_dim) const { + const Shape& shape, const HloSharding& src_sharding, + const HloSharding& dst_sharding) const { + absl::StatusOr>> + src_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding( + shape.rank(), src_sharding, device_mesh_, + /*consider_reverse_device_meshes=*/true); + absl::StatusOr>> + dst_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding( + shape.rank(), dst_sharding, device_mesh_, + /*consider_reverse_device_meshes=*/true); + if (!src_tensor_dim_to_mesh_axis.ok() || !dst_tensor_dim_to_mesh_axis.ok()) { + return OverestimateReplicationCost(shape, src_sharding, device_mesh_); + } + int64_t num_devices = device_mesh_.num_elements(); - double resharding_costs = 0.0; + std::vector collective_mesh_axes; + // Only consider sharded dimensions, do not consider replicate_on_last_dim. for (size_t i = 0; i < shape.rank(); ++i) { - // Only consider sharded dimensions, do not consider replicate_on_last_dim. - if (src_tensor_dim_to_mesh_dim[i] == dst_tensor_dim_to_mesh_dim[i]) { + if ((*src_tensor_dim_to_mesh_axis)[i] == + (*dst_tensor_dim_to_mesh_axis)[i]) { continue; } - if (dst_tensor_dim_to_mesh_dim[i] == -1 || - src_tensor_dim_to_mesh_dim[i] == -1) { - // AllToAll cost - int64_t communication_dim; - if (dst_tensor_dim_to_mesh_dim[i] != -1) { - communication_dim = dst_tensor_dim_to_mesh_dim[i]; - } else { - communication_dim = src_tensor_dim_to_mesh_dim[i]; - } - int64_t communication_bytes = ByteSizeOfShape(shape); - resharding_costs += - AllToAllCostUtil(communication_bytes, communication_dim, num_devices); - } else { + if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i], + (*src_tensor_dim_to_mesh_axis)[i])) { + // do nothing; the src is sharded more than the dest + continue; + } + if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i], + (*dst_tensor_dim_to_mesh_axis)[i])) { // Do not support this sharding, assuming it is gonna be very expensive. - return kInfinityCost; + return OverestimateReplicationCost(shape, src_sharding, device_mesh_); + } + for (int64_t mesh_dim : (*src_tensor_dim_to_mesh_axis)[i]) { + if (!(*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) { + collective_mesh_axes.push_back(mesh_dim); + } } } - return resharding_costs; + + auto is_mesh_axis_used_for_dst_sharding = [&](int64_t mesh_dim) { + int end = dst_sharding.ReplicateOnLastTileDim() + ? dst_tensor_dim_to_mesh_axis->size() - 1 + : dst_tensor_dim_to_mesh_axis->size(); + for (int i = 0; i < end; ++i) { + if ((*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) { + return true; + } + } + return false; + }; + + double resharding_cost = 0.0; + int64_t communication_bytes = ByteSizeOfShape(shape); + for (int mesh_dim : collective_mesh_axes) { + bool used_for_dst_sharding = is_mesh_axis_used_for_dst_sharding(mesh_dim); + resharding_cost += + used_for_dst_sharding + ? AllToAllCostUtil(communication_bytes, mesh_dim, num_devices) + : AllGatherCost(communication_bytes, mesh_dim); + } + return resharding_cost; } double ClusterEnvironment::CollectivePermuteCost( @@ -313,8 +362,7 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape, dst_tensor_dim_to_mesh_dim_or.value(); if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { - return ReshardingCostMixedMeshShape(shape, src_tensor_dim_to_mesh_dim, - dst_tensor_dim_to_mesh_dim); + return ReshardingCostMixedMeshShape(shape, src_spec, dst_spec); } AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim); diff --git a/xla/hlo/experimental/auto_sharding/cluster_environment.h b/xla/hlo/experimental/auto_sharding/cluster_environment.h index d17b026dd8ffb..89b81133c95d0 100644 --- a/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -145,9 +145,9 @@ class ClusterEnvironment { double AllToAllCost(double num_bytes, int mesh_dim) const; - double ReshardingCostMixedMeshShape( - const Shape& shape, absl::Span src_tensor_dim_to_mesh_dim, - absl::Span dst_tensor_dim_to_mesh_dim) const; + double ReshardingCostMixedMeshShape(const Shape& shape, + const HloSharding& src_sharding, + const HloSharding& dst_sharding) const; double CollectivePermuteCost( double num_bytes,