Skip to content

Commit

Permalink
[XLA:SPMD] Propagate shardings backward along explicit batch dims in …
Browse files Browse the repository at this point in the history
…gather/scatter instructions.

We modify `GatherOperandShardingFromOutputParallelDimensions` and `ScatterUpdateShardingFromOutputParallelDimensions` to propagate shardings along the explicit batch dims in the backward direction (result -> operands).

PiperOrigin-RevId: 681538531
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Oct 2, 2024
1 parent 29cad9d commit dfb99ba
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 184 deletions.
320 changes: 136 additions & 184 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,80 @@ bool ContainsTileSharding(const HloModule& module) {
return false;
}

template <typename T>
std::vector<int64_t> argsort(absl::Span<const T> data) {
std::vector<int64_t> indices(data.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(),
[&data](int64_t i, int64_t j) { return data[i] < data[j]; });
return indices;
}

// Given a `source_sharding`, preserve the tiles along the `source_dims` and
// replicate the rest. The `target_dims` are used to determine the order of the
// dimensions in the resulting sharding. If `source_dims` and `target_dims` are
// in the different order (i.e., different argsort results), we need to
// transpose the tile assignment.
//
// Given the following input,
// * source_sharding = {devices=[2,3,5,7,11]<=[2310]}
// * source_dims = [2, 4, 1]
// * target_dims = [2, 1, 3]
// * target_shape_rank = 5
// The result shoule be {devices=[1,11,5,3,1,14]<=[2,3,5,7,11]T(4,2,1,0,3)
// last_tile_dim_replicate}.
HloSharding PropagateShardingAlongDimsAndReplicateOthers(
const HloSharding& source_sharding, absl::Span<const int64_t> source_dims,
absl::Span<const int64_t> target_dims, int64_t target_shape_rank) {
CHECK_EQ(source_dims.size(), target_dims.size());
if (source_sharding.IsTileMaximal() || source_sharding.IsManual()) {
return source_sharding;
}

HloSharding replicate_other_dims =
PartiallyReplicateTiledShardingOnAllDimsExcept(source_sharding,
source_dims);
if (replicate_other_dims.IsTileMaximal()) {
return replicate_other_dims;
}

std::vector<int64_t> argsort_source_dims = argsort(source_dims);
std::vector<int64_t> argsort_target_dims = argsort(target_dims);
if (argsort_source_dims != argsort_target_dims) {
std::vector<int64_t> perm(
replicate_other_dims.tile_assignment().num_dimensions(), -1);
for (int64_t i = 0; i < source_dims.size(); ++i) {
perm[source_dims[argsort_target_dims[i]]] = i;
}
int64_t i = source_dims.size();
for (int64_t& perm_element : perm) {
if (perm_element == -1) {
perm_element = i++;
}
}
replicate_other_dims = TransposeSharding(replicate_other_dims, perm);
}

std::vector<int64_t> target_tile_dims(target_shape_rank, 1);
for (int i = 0; i < source_dims.size(); ++i) {
target_tile_dims[target_dims[i]] =
source_sharding.tile_assignment().dim(source_dims[i]);
}
for (int64_t i = replicate_other_dims.TiledDataRank();
i < replicate_other_dims.tile_assignment().num_dimensions(); ++i) {
target_tile_dims.push_back(replicate_other_dims.tile_assignment().dim(i));
}

auto target_tile_assignment =
replicate_other_dims.tile_assignment().Reshape(target_tile_dims);
return replicate_other_dims.ReplicateOnLastTileDim()
? HloSharding::PartialTile(target_tile_assignment,
replicate_other_dims.metadata())
: HloSharding::Subgroup(target_tile_assignment,
replicate_other_dims.subgroup_types(),
replicate_other_dims.metadata());
}

HloSharding GatherOutputShardingFromIndexIndexPassthroughDimensions(
const HloSharding& index_sharding, const HloInstruction* hlo) {
CHECK(hlo->opcode() == HloOpcode::kGather);
Expand Down Expand Up @@ -1559,71 +1633,37 @@ std::optional<HloSharding> GatherOperandShardingFromOutputParallelDimensions(
if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) {
return output_sharding;
}
auto parallel_dims = GetGatherParallelBatchDims(gather, call_graph);
if (parallel_dims) {
auto output_parallel_dims =
GetGatherParallelOutputDims(gather, *parallel_dims);
auto output_aligned_operand_parallel_dims =
parallel_dims->operand_parallel_dims;
const Shape gather_shape = gather.shape();
CHECK_EQ(output_parallel_dims.size(),
output_aligned_operand_parallel_dims.size());
DimensionVector operand_tile_assignment(gather.operand(0)->shape().rank(),
1);
DimensionVector relevant_output_dims;
for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) {
if (parallel_idx >= output_parallel_dims.size() ||
output_parallel_dims[parallel_idx] != i) {
continue;
}
const int64_t operand_dim =
output_aligned_operand_parallel_dims[parallel_idx++];
operand_tile_assignment[operand_dim] =
output_sharding.tile_assignment().dim(i);
relevant_output_dims.push_back(i);
}
HloSharding relevant_output_sharding =
PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
relevant_output_dims);
if (relevant_output_sharding.IsTileMaximal()) {
return std::move(relevant_output_sharding);
}

for (int64_t i = relevant_output_sharding.TiledDataRank();
i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
operand_tile_assignment.push_back(
relevant_output_sharding.tile_assignment().dim(i));
}
auto tile_assignment = relevant_output_sharding.tile_assignment().Reshape(
operand_tile_assignment);
return relevant_output_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment,
output_sharding.metadata())
: HloSharding::Subgroup(
tile_assignment, relevant_output_sharding.subgroup_types(),
output_sharding.metadata());

GatherScatterParallelDims parallel_dims;

const GatherDimensionNumbers& dnums = gather.gather_dimension_numbers();
if (!dnums.operand_batching_dims().empty()) {
parallel_dims.operand_parallel_dims.assign(
dnums.operand_batching_dims().begin(),
dnums.operand_batching_dims().end());
parallel_dims.indices_parallel_dims.assign(
dnums.start_indices_batching_dims().begin(),
dnums.start_indices_batching_dims().end());
}
if (std::optional<GatherScatterParallelDims> implicit_parallel_dims =
GetGatherParallelBatchDims(gather, call_graph)) {
parallel_dims.operand_parallel_dims.insert(
parallel_dims.operand_parallel_dims.end(),
implicit_parallel_dims->operand_parallel_dims.begin(),
implicit_parallel_dims->operand_parallel_dims.end());
parallel_dims.indices_parallel_dims.insert(
parallel_dims.indices_parallel_dims.end(),
implicit_parallel_dims->indices_parallel_dims.begin(),
implicit_parallel_dims->indices_parallel_dims.end());
}
return std::nullopt;
}

// Reorders `to_align` based on the order of how `target_permuted` is reordered
// from `target`, expecting the container size to be small.
absl::InlinedVector<int64_t, 1> AlignSmallContainers(
absl::Span<const int64_t> to_align, absl::Span<const int64_t> target,
absl::Span<const int64_t> target_permuted) {
CHECK(absl::c_is_permutation(target_permuted, target));
CHECK_EQ(to_align.size(), target.size());
absl::InlinedVector<int64_t, 1> to_align_permuted(to_align.size());
for (auto i = 0; i < target.size(); ++i) {
// This is small so just look linearly.
for (auto j = 0; j < target_permuted.size(); ++j) {
if (target_permuted[j] == target[i]) {
to_align_permuted[j] = to_align[i];
break;
}
}
if (parallel_dims.operand_parallel_dims.empty()) {
return std::nullopt;
}
return to_align_permuted;

return PropagateShardingAlongDimsAndReplicateOthers(
output_sharding, GetGatherParallelOutputDims(gather, parallel_dims),
parallel_dims.operand_parallel_dims, gather.operand(0)->shape().rank());
}

} // namespace
Expand Down Expand Up @@ -1776,58 +1816,37 @@ std::optional<HloSharding> ScatterUpdateShardingFromOutputParallelDimensions(
if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) {
return output_sharding;
}
auto parallel_dims = GetScatterParallelBatchDims(scatter, call_graph);
if (parallel_dims) {
auto update_parallel_dims =
GetScatterParallelUpdateDims(scatter, *parallel_dims);
auto index_aligned_operand_parallel_dims =
parallel_dims->operand_parallel_dims;
auto operand_parallel_dims_sorted = index_aligned_operand_parallel_dims;
absl::c_sort(operand_parallel_dims_sorted);
auto operand_aligned_update_parallel_dims = AlignSmallContainers(
update_parallel_dims, index_aligned_operand_parallel_dims,
operand_parallel_dims_sorted);
const Shape scatter_shape = scatter.shape().IsTuple()
? scatter.shape().tuple_shapes()[0]
: scatter.shape();
CHECK_EQ(update_parallel_dims.size(),
index_aligned_operand_parallel_dims.size());
DimensionVector update_tile_assignment(
scatter.scatter_updates()[0]->shape().rank(), 1);
DimensionVector relevant_output_dims;
for (int i = 0, parallel_idx = 0; i < scatter_shape.rank(); ++i) {
if (parallel_idx >= operand_parallel_dims_sorted.size() ||
operand_parallel_dims_sorted[parallel_idx] != i) {
continue;
}
const int64_t update_dim =
operand_aligned_update_parallel_dims[parallel_idx++];
update_tile_assignment[update_dim] =
output_sharding.tile_assignment().dim(i);
relevant_output_dims.push_back(i);
}
HloSharding relevant_output_sharding =
PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding,
relevant_output_dims);
if (relevant_output_sharding.IsTileMaximal()) {
return std::move(relevant_output_sharding);
}

for (int64_t i = relevant_output_sharding.TiledDataRank();
i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) {
update_tile_assignment.push_back(
relevant_output_sharding.tile_assignment().dim(i));
}
auto tile_assignment = relevant_output_sharding.tile_assignment().Reshape(
update_tile_assignment);
return relevant_output_sharding.ReplicateOnLastTileDim()
? HloSharding::PartialTile(tile_assignment,
output_sharding.metadata())
: HloSharding::Subgroup(
tile_assignment, relevant_output_sharding.subgroup_types(),
output_sharding.metadata());

GatherScatterParallelDims parallel_dims;

const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers();
if (!dnums.input_batching_dims().empty()) {
parallel_dims.operand_parallel_dims.assign(
dnums.input_batching_dims().begin(), dnums.input_batching_dims().end());
parallel_dims.indices_parallel_dims.assign(
dnums.scatter_indices_batching_dims().begin(),
dnums.scatter_indices_batching_dims().end());
}
if (std::optional<GatherScatterParallelDims> implicit_parallel_dims =
GetScatterParallelBatchDims(scatter, call_graph)) {
parallel_dims.operand_parallel_dims.insert(
parallel_dims.operand_parallel_dims.end(),
implicit_parallel_dims->operand_parallel_dims.begin(),
implicit_parallel_dims->operand_parallel_dims.end());
parallel_dims.indices_parallel_dims.insert(
parallel_dims.indices_parallel_dims.end(),
implicit_parallel_dims->indices_parallel_dims.begin(),
implicit_parallel_dims->indices_parallel_dims.end());
}

if (parallel_dims.operand_parallel_dims.empty()) {
return std::nullopt;
}
return std::nullopt;

return PropagateShardingAlongDimsAndReplicateOthers(
output_sharding, parallel_dims.operand_parallel_dims,
GetScatterParallelUpdateDims(scatter, parallel_dims),
scatter.scatter_updates()[0]->shape().rank());
}

HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions(
Expand Down Expand Up @@ -2384,6 +2403,7 @@ GetGatherOutputOrScatterUpdateParallelDims(
++idx_dim;
}
}
CHECK_EQ(output_parallel_dims.size(), indices_parallel_dims.size());
return output_parallel_dims;
}

Expand Down Expand Up @@ -2478,82 +2498,14 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims(
return passthrough_dims;
}

template <typename T>
std::vector<int64_t> argsort(absl::Span<const T> data) {
std::vector<int64_t> indices(data.size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(),
[&data](int64_t i1, int64_t i2) { return data[i1] < data[i2]; });
return indices;
}

HloSharding InferGatherScatterParallelShardingFromOperandSharding(
const HloSharding& operand_sharding, const Shape& operand_shape,
const Shape& shape,
absl::Span<const int64_t> output_aligned_operand_parallel_dims,
absl::Span<const int64_t> output_parallel_dims) {
if (operand_sharding.IsTileMaximal()) {
return operand_sharding;
}

HloSharding replicate_non_parallel_dims =
PartiallyReplicateTiledShardingOnAllDimsExcept(
operand_sharding, output_aligned_operand_parallel_dims);
if (replicate_non_parallel_dims.IsTileMaximal()) {
return replicate_non_parallel_dims;
}

// output_aligned_operand_parallel_dims and output_parallel_dims may not be
// in the same order. We need to transpose the sharding accordingly. For
// example, if output_aligned_operand_parallel_dims = [2, 4, 1] and
// output_parallel_dims = [2, 1, 3], the sharding needs to be transposed with
// perm = [3, 2, 1, 4, 0] to adjust the order of devices.
std::vector<int64_t> argsort_output_aligned_operand_parallel_dims =
argsort(output_aligned_operand_parallel_dims);
std::vector<int64_t> argsort_output_parallel_dims =
argsort(output_parallel_dims);
if (argsort_output_aligned_operand_parallel_dims !=
argsort_output_parallel_dims) {
std::vector<int64_t> perm(
replicate_non_parallel_dims.tile_assignment().num_dimensions(), -1);
for (int64_t i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) {
perm[output_aligned_operand_parallel_dims
[argsort_output_parallel_dims[i]]] = i;
}
int64_t i = output_aligned_operand_parallel_dims.size();
for (int64_t& perm_element : perm) {
if (perm_element == -1) {
perm_element = i++;
}
}
replicate_non_parallel_dims =
TransposeSharding(replicate_non_parallel_dims, perm);
}

// Collect tile dimensions in the operand.
std::vector<int64_t> output_tile_dims(shape.rank(), 1);
for (int i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) {
const int64_t operand_idx = output_aligned_operand_parallel_dims[i];
const int64_t output_idx = output_parallel_dims[i];
output_tile_dims[output_idx] =
operand_sharding.tile_assignment().dim(operand_idx);
}
for (int64_t i = replicate_non_parallel_dims.TiledDataRank();
i < replicate_non_parallel_dims.tile_assignment().num_dimensions();
++i) {
output_tile_dims.push_back(
replicate_non_parallel_dims.tile_assignment().dim(i));
}

auto output_tile_assignment =
replicate_non_parallel_dims.tile_assignment().Reshape(output_tile_dims);
return replicate_non_parallel_dims.ReplicateOnLastTileDim()
? HloSharding::PartialTile(output_tile_assignment,
replicate_non_parallel_dims.metadata())
: HloSharding::Subgroup(
output_tile_assignment,
replicate_non_parallel_dims.subgroup_types(),
replicate_non_parallel_dims.metadata());
return PropagateShardingAlongDimsAndReplicateOthers(
operand_sharding, output_aligned_operand_parallel_dims,
output_parallel_dims, shape.rank());
}

std::string GroupedSharding::ToString() const {
Expand Down
Loading

0 comments on commit dfb99ba

Please sign in to comment.