Skip to content

Commit

Permalink
[XLA:SPMD] Fix scatter index-parallel partitioning issues.
Browse files Browse the repository at this point in the history
1. Fix gather/scatter partitioning in the index-parallel case: the operand/updates sharding must be aligned with the indices' sharding.
2. Remove the assumption that gather/scatter index-parallel dim detection returns sorted dims, because sorting loses the dimension correspondence information.

PiperOrigin-RevId: 675439880
  • Loading branch information
Tongfei-Guo authored and Google-ML-Automation committed Sep 17, 2024
1 parent cc69e67 commit 578f9a6
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 19 deletions.
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/transforms:hlo_constant_splitter",
"//xla/hlo/utils:hlo_live_range",
"//xla/hlo/utils:hlo_sharding_util",
Expand All @@ -62,7 +63,6 @@ cc_library(
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_dce",
"//xla/service:hlo_memory_scheduler",
"//xla/service:hlo_pass",
"//xla/service:hlo_value",
"//xla/service:optimize_input_output_buffer_alias",
"//xla/service:sharding_propagation",
Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/experimental/auto_sharding/auto_sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/hlo/utils/hlo_live_range.h"
#include "xla/service/call_graph.h"
#include "xla/service/hlo_alias_analysis.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/shape.h"

namespace xla {
Expand Down
20 changes: 11 additions & 9 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2319,7 +2319,6 @@ std::optional<GatherScatterParallelDims> GetGatherScatterBatchParallelDims(
index_parallel_in_dim[i] = -1;
}
}
absl::c_sort(indices_parallel_dims);
if (!indices_parallel_dims.empty()) {
return GatherScatterParallelDims{
indices_parallel_dims, operand_parallel_dims, index_parallel_in_dim};
Expand Down Expand Up @@ -2362,15 +2361,18 @@ GetGatherOutputOrScatterUpdateParallelDims(
int64_t index_vector_dim, absl::Span<const int64_t> offset_or_window_dims) {
absl::InlinedVector<int64_t, 1> output_parallel_dims;
auto indices_parallel_dims = parallel_dim.indices_parallel_dims;
for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) {
if (absl::c_linear_search(offset_or_window_dims, i)) {
continue;
}
const int index_dim = idx_dim < index_vector_dim ? idx_dim : idx_dim + 1;
if (absl::c_binary_search(indices_parallel_dims, index_dim)) {
output_parallel_dims.push_back(i);
for (int64_t indices_parallel_dim : indices_parallel_dims) {
for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) {
if (absl::c_linear_search(offset_or_window_dims, i)) {
continue;
}
const int index_dim = idx_dim < index_vector_dim ? idx_dim : idx_dim + 1;
if (indices_parallel_dim == index_dim) {
output_parallel_dims.push_back(i);
break;
}
++idx_dim;
}
++idx_dim;
}
return output_parallel_dims;
}
Expand Down
20 changes: 12 additions & 8 deletions xla/service/spmd/gather_scatter_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,11 +684,13 @@ absl::StatusOr<HloInstruction*> PartitionGatherIndexParallelDimensions(
hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(),
indices_parallel_dims);
const GroupedSharding operand_grouped =
hlo_sharding_util::GroupShardingOnDims(operand.sharding(),
operand_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
operand.sharding(), operand_parallel_dims),
new_indices_grouped);
const GroupedSharding output_grouped =
hlo_sharding_util::GroupShardingOnDims(gather_output_sharding,
output_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
gather_output_sharding, output_parallel_dims),
new_indices_grouped);
PartitionedHlo per_group_operand =
PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups);
PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo(
Expand Down Expand Up @@ -1130,11 +1132,13 @@ absl::StatusOr<HloInstruction*> PartitionScatterIndexParallelDimensions(
hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(),
indices_parallel_dims);
const GroupedSharding operand_grouped =
hlo_sharding_util::GroupShardingOnDims(operands[0].sharding(),
operand_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
operands[0].sharding(), operand_parallel_dims),
new_indices_grouped);
const GroupedSharding update_grouped =
hlo_sharding_util::GroupShardingOnDims(updates[0].sharding(),
update_parallel_dims);
AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(
updates[0].sharding(), update_parallel_dims),
new_indices_grouped);
const GroupedSharding& output_grouped = operand_grouped;
std::vector<PartitionedHlo> per_group_operands =
PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups);
Expand Down

0 comments on commit 578f9a6

Please sign in to comment.