From d391119197eab771a84c1f8a59a7f50b7da4b43d Mon Sep 17 00:00:00 2001 From: Juliana Franco Date: Thu, 19 Sep 2024 08:25:01 -0700 Subject: [PATCH] Skip axes of size 1 when building new TensorShardingAttr. Without this CL, the test added would result in a sharding including the axis of size 1 (i.e. "x") in the first dimension's sharding. PiperOrigin-RevId: 676423383 --- xla/service/spmd/shardy/mhlo_round_trip/BUILD | 20 +++++ .../shardy/mhlo_round_trip/mhlo_import.cc | 3 + .../mhlo_round_trip/mhlo_import_test.cc | 76 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc diff --git a/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/xla/service/spmd/shardy/mhlo_round_trip/BUILD index de973856c70da..ea6e91fd56017 100644 --- a/xla/service/spmd/shardy/mhlo_round_trip/BUILD +++ b/xla/service/spmd/shardy/mhlo_round_trip/BUILD @@ -1,6 +1,7 @@ # Import/Export passes for going from `sdy.sharding`s to `mhlo.sharding`s and vice versa. load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -121,6 +122,25 @@ cc_library( ], ) +xla_cc_test( + name = "mhlo_import_test", + srcs = ["mhlo_import_test.cc"], + deps = [ + ":mhlo_import", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/ir:register", + "@tsl//tsl/platform:status_matchers", + "@tsl//tsl/platform:test", + ], +) + cc_library( name = "shard_map_import", srcs = ["shard_map_import.cc"], diff --git a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index 02553d14e9a98..50177a1268f9e 100644 --- a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -441,6 +441,9 @@ TensorShardingAttr convertToSdySharding( // break it when we find common mesh axes. while (product < localAxisSize) { MeshAxisAttr axisAttr = globalMesh.getAxes()[globalAxisIndex++]; + if (axisAttr.getSize() == 1) { + continue; + } globalAxes.push_back(AxisRefAttr::get(ctx, axisAttr.getName())); product *= axisAttr.getSize(); } diff --git a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc new file mode 100644 index 0000000000000..f1635a804ab33 --- /dev/null +++ b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h" + +#include + +#include +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/LLVM.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/register.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "tsl/platform/test.h" + +namespace mlir::sdy { + +namespace { + +TEST(MhloImportTest, SkipFirstAxisOfSize1) { + MLIRContext context; + loadAllRequiredDialects(&context); + SmallVector axes; + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "x", 1)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "y", 4)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "z", 2)); + auto mesh = sdy::MeshAttr::get(&context, axes); + + TensorShardingAttr sharding = xla::sdy::convertToSdySharding( + /*hloSharding=*/xla::HloSharding::IotaTile({4, 2}), + /*globalMesh=*/mesh, + /*deviceIdToMaximalMeshName=*/ + llvm::SmallDenseMap(), /*rank=*/2, + /*openDims=*/true); + EXPECT_EQ(attributeToString(sharding), + "#sdy.sharding<@mesh, [{\"y\", ?}, {\"z\", ?}]>"); +} + +// As above, but the middle axis is the one with size 1. +TEST(MhloImportTest, SkipSecondAxisOfSize1) { + MLIRContext context; + loadAllRequiredDialects(&context); + SmallVector axes; + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "y", 4)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "x", 1)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "z", 2)); + auto mesh = sdy::MeshAttr::get(&context, axes); + + TensorShardingAttr sharding = xla::sdy::convertToSdySharding( + /*hloSharding=*/xla::HloSharding::IotaTile({4, 2}), + /*globalMesh=*/mesh, + /*deviceIdToMaximalMeshName=*/ + llvm::SmallDenseMap(), /*rank=*/2, + /*openDims=*/true); + EXPECT_EQ(attributeToString(sharding), + "#sdy.sharding<@mesh, [{\"y\", ?}, {\"z\", ?}]>"); +} + +} // namespace +} // namespace mlir::sdy