Support shape transpose in hlo_sharding_util::ReshapeSharding.
Before this CL, `hlo_sharding_util::ReshapeSharding` could only handle cases where the source and target shapes can be transformed into each other by merging and splitting dimension sizes. It returned `std::nullopt` if a transpose was needed between the source and target shapes.

This CL extracts `gcd(source_sharding_tile_size, target_shape_dim)` when `source_shape_dim % source_sharding_tile_size == 0` on the major dimensions. An example is shown below.
```
input_shape: [6, 4]
output_shape: [2, 2, 3, 2]
input_sharding: {devices=[6,1]<=[6]}
```
The output_sharding is `{devices=[2,1,1,1,3]<=[6] last_tile_dim_replicate}`. Before this CL, the output_sharding was `{replicated}`.
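
For intuition, here is a minimal standalone sketch (plain C++17, not the XLA implementation; the variable names are illustrative only) of the arithmetic behind this example: the output keeps `gcd(2, 6) = 2` tiles on its most-major dimension, and the remaining `6 / 2 = 3` tiles fall back to a replicated subgroup.
```
#include <cstdint>
#include <iostream>
#include <numeric>

int main() {
  const int64_t source_dim = 6;    // most-major dimension of the input shape [6, 4]
  const int64_t sharding_dim = 6;  // tiles on that dimension, from devices=[6,1]
  const int64_t target_dim = 2;    // most-major dimension of the output shape [2, 2, 3, 2]

  // The transposed-reshape case is only handled when the source dimension is
  // divisible by its tile count.
  if (source_dim % sharding_dim != 0) return 1;

  // Keep gcd(target_dim, sharding_dim) tiles on the target dimension; the
  // leftover tiles become a replicated subgroup.
  const int64_t kept = std::gcd(target_dim, sharding_dim);  // 2
  const int64_t replicated = sharding_dim / kept;           // 3

  std::cout << "devices=[" << kept << ",1,1,1," << replicated
            << "]<=[6] last_tile_dim_replicate\n";
  return 0;
}
```
The same divisibility check and `std::gcd` call appear in the `ReshapeSharding` changes below.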
PiperOrigin-RevId: 621333803
ZixuanJiang authored and copybara-github committed Apr 5, 2024
1 parent be5c637 commit d41671e
Showing 3 changed files with 153 additions and 42 deletions.
116 changes: 74 additions & 42 deletions xla/hlo/utils/hlo_sharding_util.cc
@@ -18,9 +18,11 @@ limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <tuple>
@@ -690,30 +692,34 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
return sharding;
}

// In case of a tiled sharding the reshaped sharding will be a valid if the
// In case of a tiled sharding, the reshaped sharding will be valid if the
// reshape is composed from the following operations:
// * Adding or removing dimensions with size 1.
// * Merging consecutive dimensions where only the most major is sharded.
// * Splitting a dimension to consecutive dimensions.
// * Any reshaping of unsharded dimensions.
// Note that merge and split can happen consecutively on the same dimension,
// e.g., f32[1024,256,1024] to f32[128,2048,1024] can be considered that 1024
// gets split into 128 and 8, but 8 then gets merged with 256. We use stacks
// to make supporting such cases easy.
const Shape tile_shape = sharding.TileShape(source_shape);
//
// Merge and split can happen consecutively on the same dimension, e.g.,
// f32[1024,256] to f32[128,2048] can be considered that 1024 gets split into
// 128 and 8, but 8 then gets merged with 256. We use stacks to make
// supporting such cases easy.
//
// If transpose is needed between source and target shapes, we use the GCD of
// (target_shape_dim, sharding_dim) if source_shape_dim % sharding_dim == 0.
// For example, given the source_shape f32[6,4], target_shape f32[4,6] and
// sharding {devices=[6,1]<=[6]}, the output sharding is {devices=[2,1,3]<=[6]
// last_tile_dim_replicate}.
DimensionVector target_tile_assignment_dimensions;
DimensionVector source_dims_stack(source_shape.rank());
DimensionVector target_dims_stack(target_shape.rank());
DimensionVector sharding_tile_dims_stack(source_shape.rank());
int64_t added_to_partially_replicated = 1;
for (int64_t i = 0; i < source_shape.rank(); ++i) {
source_dims_stack[i] = source_shape.dimensions(source_shape.rank() - 1 - i);
sharding_tile_dims_stack[i] =
sharding.tile_assignment().dim(source_shape.rank() - 1 - i);
}
for (int64_t i = 0; i < target_shape.rank(); ++i) {
target_dims_stack[i] = target_shape.dimensions(target_shape.rank() - 1 - i);
}
DimensionVector source_dims_stack(source_shape.dimensions().rbegin(),
source_shape.dimensions().rend());
DimensionVector target_dims_stack(target_shape.dimensions().rbegin(),
target_shape.dimensions().rend());
DimensionVector sharding_tile_dims_stack(
sharding.tile_assignment().dimensions().begin(),
sharding.tile_assignment().dimensions().begin() + source_shape.rank());
std::reverse(sharding_tile_dims_stack.begin(),
sharding_tile_dims_stack.end());

bool inplace_add_sharding_dim = false;
auto append_sharding_dim = [&](int64_t size) {
if (inplace_add_sharding_dim) {
@@ -723,6 +729,7 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
}
inplace_add_sharding_dim = false;
};

while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
if (target_dims_stack.empty()) {
if (Product(sharding_tile_dims_stack) != 1) {
@@ -731,15 +738,14 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
break;
}
int64_t s_size = 1;
int64_t t_size = 1;
int64_t s_partitions = 1;
if (!source_dims_stack.empty()) {
s_size = source_dims_stack.back();
source_dims_stack.pop_back();
s_partitions = sharding_tile_dims_stack.back();
sharding_tile_dims_stack.pop_back();
}
t_size = target_dims_stack.back();
int64_t t_size = target_dims_stack.back();
target_dims_stack.pop_back();
if (s_partitions * Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
@@ -767,15 +773,20 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
sharding_tile_dims_stack.push_back(s_partitions);
} else if (s_size == 1) {
// Trivial dimension removed.
if (s_partitions != 1) {
added_to_partially_replicated *= s_partitions;
}
target_dims_stack.push_back(t_size);
} else if (s_size > t_size) {
// Dimension split.
if (s_size % t_size != 0 || s_size % s_partitions != 0) {
if (s_size % s_partitions != 0) {
return std::nullopt;
}
if (s_size % t_size != 0) {
// Transpose is needed between source and target shapes.
auto gcd = std::gcd(t_size, s_partitions);
if (gcd > 1) {
append_sharding_dim(gcd);
}
break;
}
if (t_size % s_partitions == 0) {
append_sharding_dim(s_partitions);
// We have part of the s_size unprocessed, so put it back to stack.
@@ -787,15 +798,25 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
source_dims_stack.push_back(s_size / t_size);
sharding_tile_dims_stack.push_back(s_partitions / t_size);
} else {
return std::nullopt;
break;
}
} else {
// Dimension merge. Also merge the source dimension with the next, and
// process it next time.
if (source_dims_stack.empty()) {
LOG(ERROR) << "source_dims_stack is empty";
}
if (s_size % s_partitions != 0) {
return std::nullopt;
}
CHECK(!source_dims_stack.empty());
if (t_size % s_size != 0) {
// Transpose is needed between source and target shapes.
auto gcd = std::gcd(t_size, s_partitions);
if (gcd > 1) {
append_sharding_dim(gcd);
}
break;
}
if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
// If the next dimension to combine is sharded, we require that the
// current dimension's shard size to be 1. Otherwise, the new shard
@@ -810,31 +831,42 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
if (Product(target_tile_assignment_dimensions) == 1) {
return std::nullopt;
}
while (target_tile_assignment_dimensions.size() < target_shape.rank()) {
target_tile_assignment_dimensions.push_back(1);
}
for (int64_t i = sharding.TiledDataRank();
i < sharding.tile_assignment().num_dimensions(); ++i) {
target_tile_assignment_dimensions.push_back(
sharding.tile_assignment().dim(i));
i == sharding.SubgroupReplicationDim()
? 1
: sharding.tile_assignment().dim(i));
}

auto subgroup_types = sharding.subgroup_types();
// If we added dimensions to the partially replicated dimension then add the
// additional dimension on the partially replicated tiling.
if (added_to_partially_replicated > 1) {
if (sharding.ReplicateOnLastTileDim()) {
target_tile_assignment_dimensions.back() *= added_to_partially_replicated;
auto partially_replicated = std::div(
sharding.TotalNumTiles(), Product(target_tile_assignment_dimensions));
if (partially_replicated.rem != 0) {
LOG(ERROR) << "sharding: " << sharding.ToString()
<< "; target_tile_assignment_dimensions: "
<< absl::StrJoin(target_tile_assignment_dimensions, ",");
}
if (partially_replicated.quot > 1) {
if (sharding.HasPartialReplication()) {
target_tile_assignment_dimensions[sharding.SubgroupReplicationDim() -
sharding.TiledDataRank() +
target_shape.rank()] =
partially_replicated.quot;
} else {
target_tile_assignment_dimensions.push_back(
added_to_partially_replicated);
target_tile_assignment_dimensions.push_back(partially_replicated.quot);
}
// If subgroup_types doesn't have partially replicated as a sharding type
// then add it.
if (subgroup_types.empty() ||
subgroup_types.back() != OpSharding::REPLICATED) {
subgroup_types.push_back(OpSharding::REPLICATED);
}
}
// If subgroup_types doesn't have already partially replicated as a sharding
// type then add it.
if ((sharding.ReplicateOnLastTileDim() ||
added_to_partially_replicated > 1) &&
(subgroup_types.empty() ||
subgroup_types.back() != OpSharding::REPLICATED)) {
subgroup_types.push_back(OpSharding::REPLICATED);
}

auto new_tile_assignment =
sharding.tile_assignment().Reshape(target_tile_assignment_dimensions);
return HloSharding::Subgroup(new_tile_assignment, subgroup_types,
33 changes: 33 additions & 0 deletions xla/hlo/utils/hlo_sharding_util_test.cc
@@ -337,6 +337,39 @@ TEST(HloShardingUtilTest, ReshapeToTileDimension4D) {
}
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose1) {
Shape input_shape = ShapeUtil::MakeShape(F32, {6, 4});
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 2});
HloSharding input_sharding = HloSharding::IotaTile({6, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({2, 1, 1, 1, 3}));
HloSharding result = PropagateShardingThroughReshape(
input_shape, output_shape, input_sharding);
EXPECT_EQ(result, output_sharding);
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose2) {
Shape input_shape = ShapeUtil::MakeShape(F32, {6, 4});
Shape output_shape = ShapeUtil::MakeShape(F32, {4, 6});
HloSharding input_sharding = HloSharding::IotaTile({6, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({2, 1, 3}));
HloSharding result = PropagateShardingThroughReshape(
input_shape, output_shape, input_sharding);
EXPECT_EQ(result, output_sharding);
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTranspose3) {
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 6, 5});
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 5, 3});
HloSharding input_sharding = HloSharding::IotaTile({2, 6, 1});
HloSharding output_sharding =
HloSharding::PartialTile(TileAssignment({2, 1, 2, 1, 1, 3}));
HloSharding result = PropagateShardingThroughReshape(
input_shape, output_shape, input_sharding);
EXPECT_EQ(result, output_sharding);
}

TEST(HloShardingUtilTest, PropagateReshapeShardingTiledSplitPartialMatch) {
Shape input_shape = ShapeUtil::MakeShape(F32, {14, 16});
Shape output_shape = ShapeUtil::MakeShape(F32, {2, 7, 4, 4});
46 changes: 46 additions & 0 deletions xla/service/sharding_propagation_test.cc
@@ -17,6 +17,7 @@ limitations under the License.

#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
@@ -1507,6 +1508,51 @@ ENTRY %reshape {
}
}

TEST_P(ParameterizedMetadataTest, ReshapeForwardPassTranspose1) {
const char* const hlo_string = R"(
HloModule module
ENTRY %reshape {
%param0 = f32[6,4,5] parameter(0), sharding={devices=[6,2,1]<=[12] metadata={op_name="a"}}
%reshape.1 = f32[2,3,20] reshape(%param0)
%reshape.2 = f32[2,4,3,5] reshape(%param0)
%reshape.3 = f32[20,6] reshape(%param0)
%reshape.4 = f32[3,5,8] reshape(%param0)
%reshape.5 = f32[10,4,3] reshape(%param0)
%reshape.6 = f32[5,8,3] reshape(%param0)
ROOT %tuple = tuple(%reshape.1, %reshape.2, %reshape.3, %reshape.4, %reshape.5, %reshape.6)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
if (GetParam().clear_metadata) {
ClearMetadata(module.get());
}
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(/*is_spmd=*/false, GetParam().propagate_metadata)
.Run(module.get()));
XLA_VLOG_LINES(1, module->ToString());
EXPECT_TRUE(changed);

std::vector<std::pair<std::string, std::string>> instruction_and_sharding = {
{"reshape.1", "{devices=[2,3,2]<=[12]}"},
{"reshape.2", "{devices=[2,1,1,1,6]<=[12] last_tile_dim_replicate}"},
{"reshape.3", "{devices=[2,1,6]<=[12] last_tile_dim_replicate}"},
{"reshape.4", "{devices=[3,1,1,4]<=[12] last_tile_dim_replicate}"},
{"reshape.5", "{devices=[2,1,1,6]<=[12] last_tile_dim_replicate}"},
{"reshape.6", "{replicated}"}};
for (const auto& [name, sharding] : instruction_and_sharding) {
auto* instruction = FindInstruction(module.get(), name);
ASSERT_NE(instruction, nullptr);
EXPECT_THAT(instruction, op::Sharding(sharding));
if (GetParam().propagate_metadata && !GetParam().clear_metadata) {
EXPECT_THAT(instruction->sharding(),
ShardingMetadata({CreateMetadata("a")}));
} else {
EXPECT_THAT(instruction->sharding(), ShardingMetadata({}));
}
}
}

TEST_P(ParameterizedMetadataTest, ReshapeBackwardPass) {
const char* const hlo_string = R"(
HloModule module
