Skip to content

Commit

Permalink
#sdy Add sdy-round-trip-shard-map-export Pass.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674303254
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Sep 13, 2024
1 parent bff7fd7 commit 1135d3c
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 10 deletions.
1 change: 1 addition & 0 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ xla_cc_binary(
"//xla/service/spmd/shardy/sdy_round_trip:export_shardings",
"//xla/service/spmd/shardy/sdy_round_trip:import_shardings",
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
"//xla/service/spmd/shardy/sdy_round_trip:shard_map_export",
"//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo",
"//xla/service/spmd/shardy/sdy_round_trip/test_utils:testing_pipeline",
"@llvm-project//mlir:AllPassesAndDialects",
Expand Down
20 changes: 20 additions & 0 deletions xla/service/spmd/shardy/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ inline constexpr llvm::StringRef kPythonIntegrationComplete =
// XLA pipeline, so no HLO<->MLIR round-tripping.
inline constexpr llvm::StringRef kUseTupleArgs = "xla.sdy.use_tuple_args";

// Attribute name for the in shardings of a `ManualComputationOp`.
inline constexpr llvm::StringRef kInShardings = "xla.sdy.in_shardings";

// Attribute name for the out shardings of a `ManualComputationOp`.
inline constexpr llvm::StringRef kOutShardings = "xla.sdy.out_shardings";

// Attribute name for the manual axes of a `ManualComputationOp`.
inline constexpr llvm::StringRef kManualAxes = "xla.sdy.manual_axes";

// The target name of the custom call that will store the various attrs of a
// `ManualComputationOp` and a reference to a `FuncOp` that is the body of the
// original `ManualComputationOp`.
inline constexpr llvm::StringRef kManualComputationCustomCallTargetName =
"xla.sdy.ManualComputation";

// The function name of the of the body of a `ManualComputationOp` during Shardy
// round tripping. Used
inline constexpr llvm::StringRef kManualComputationBodyFuncName =
"xla.sdy.manual_computation_body";

// The name of the global mesh.
inline constexpr llvm::StringRef kGlobalMeshName = "mesh";

Expand Down
2 changes: 1 addition & 1 deletion xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
// `mhlo.constant` (which is foldable), therefore greedy pattern rewriters
// shouldn't be applied before converting to HLO as they apply folding.
pm.addPass(createExportOpsPass());
pm.addPass(createShardMapExportPass());
pm.addPass(createMhloRoundTripShardMapExportPass());
pm.addPass(createExportMhloShardingsPass());
}

Expand Down
10 changes: 6 additions & 4 deletions xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ class ShardMapExportPass
}
}

StringRef getArgument() const override { return "xla-sdy-shard-map-export"; }
StringRef getArgument() const override {
return "xla-mhlo-round-trip-shard-map-export";
}

StringRef getDescription() const override {
return "Replaces sdy::ManualComputationOp with the pattern that XLA "
Expand All @@ -314,12 +316,12 @@ class ShardMapExportPass

} // namespace

std::unique_ptr<mlir::Pass> createShardMapExportPass() {
std::unique_ptr<mlir::Pass> createMhloRoundTripShardMapExportPass() {
return std::make_unique<ShardMapExportPass>();
}

void registerShardMapExportPass() {
mlir::registerPass(createShardMapExportPass);
void registerMhloRoundTripShardMapExportPass() {
mlir::registerPass(createMhloRoundTripShardMapExportPass);
}

} // namespace sdy
Expand Down
6 changes: 3 additions & 3 deletions xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ namespace sdy {
// Creates a pass that converts the `ManualComputationOp` into the pattern that
// the XLA compiler recognizes. This pass also exports fully or partially manual
// shardings, while other shardings are processed in `ExportMhloShardingsPass`.
std::unique_ptr<mlir::Pass> createShardMapExportPass();
std::unique_ptr<mlir::Pass> createMhloRoundTripShardMapExportPass();

// Registers the xla-sdy-shard-map-export pass.
void registerShardMapExportPass();
// Registers the xla-mhlo-round-trip-shard-map-export pass.
void registerMhloRoundTripShardMapExportPass();

} // namespace sdy
} // namespace xla
Expand Down
4 changes: 3 additions & 1 deletion xla/service/spmd/shardy/sdy_opt_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"
#include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h"
#include "xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.h"

Expand All @@ -60,14 +61,15 @@ int main(int argc, char** argv) {

xla::sdy::registerMhloExportPipeline();
xla::sdy::registerMhloExportShardingsPass();
xla::sdy::registerShardMapExportPass();
xla::sdy::registerMhloRoundTripShardMapExportPass();
xla::sdy::registerExportOpsPass();

xla::sdy::registerSdyRoundTripMhloToHloToMhloPass();
xla::sdy::registerSdyRoundTripExportShardingsPass();
xla::sdy::registerSdyRoundTripImportShardingsPass();
xla::sdy::registerSdyRoundTripExportOpsPass();
xla::sdy::registerSdyRoundTripExportPipeline();
xla::sdy::registerSdyRoundTripShardMapExportPass();
xla::sdy::registerSdyRoundTripImportPipeline();
xla::sdy::registerSdyRoundTripTestingPipeline();

Expand Down
18 changes: 18 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,24 @@ cc_library(
],
)

cc_library(
name = "shard_map_export",
srcs = ["shard_map_export.cc"],
hdrs = ["shard_map_export.h"],
deps = [
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "pipelines",
srcs = ["pipelines.cc"],
Expand Down
127 changes: 127 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"

#include <memory>

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/DialectConversion.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"

namespace xla {
namespace sdy {

namespace {

using ::mlir::MLIRContext;
using ::mlir::ModuleOp;
using ::mlir::StringRef;
using ::mlir::func::FuncOp;

namespace stablehlo = ::mlir::stablehlo;
namespace sdy = ::mlir::sdy;

class SdyRoundTripShardMapExportPass
: public mlir::PassWrapper<SdyRoundTripShardMapExportPass,
mlir::OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SdyRoundTripShardMapExportPass)

void runOnOperation() final {
ModuleOp moduleOp = getOperation();
MLIRContext* context = moduleOp.getContext();
mlir::SymbolTableCollection symbolTableCollection;
mlir::SymbolTable& symbolTable =
symbolTableCollection.getSymbolTable(moduleOp);
auto rewriter = mlir::IRRewriter(context);
moduleOp->walk([&](sdy::ManualComputationOp manualComputation) {
rewriter.setInsertionPointToEnd(&moduleOp.getRegion().front());
auto funcOp = rewriter.create<FuncOp>(
manualComputation.getLoc(), kManualComputationBodyFuncName,
rewriter.getFunctionType(
manualComputation.getBody().getArgumentTypes(),
sdy::getBodyTerminatorOpOperandTypes(manualComputation)));
sdy::inlineRegionAndConvertTerminatorOp<mlir::func::ReturnOp>(
manualComputation.getBody(), funcOp.getBody());
mlir::StringAttr funcName = symbolTable.insert(funcOp);

rewriter.setInsertionPoint(manualComputation);
auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
manualComputation.getLoc(), manualComputation.getResultTypes(),
manualComputation->getOperands());
customCallOp.setCallTargetName(kManualComputationCustomCallTargetName);
customCallOp.setCalledComputationsAttr(
rewriter.getArrayAttr(mlir::FlatSymbolRefAttr::get(funcName)));
addFrontendAttribute(customCallOp, kInShardings,
manualComputation.getInShardings());
addFrontendAttribute(customCallOp, kOutShardings,
manualComputation.getOutShardings());
addFrontendAttribute(customCallOp, kManualAxes,
manualComputation.getManualAxesAttr());
rewriter.replaceOp(manualComputation, customCallOp->getResults());
});
}

StringRef getArgument() const override {
return "xla-sdy-round-trip-shard-map-export";
}

StringRef getDescription() const override {
return "Converts the body of a ManualComputationOp to a separate function "
"with a CustomCallOp of the same name referring to it. The "
"CustomCallOp saves the in/out shardings and manual axes as "
"frontend attrs for HLO round tripping.";
}
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<stablehlo::StablehloDialect>();
}
};

} // namespace

void registerSdyRoundTripShardMapExportPass() {
mlir::registerPass(createSdyRoundTripShardMapExportPass);
}

std::unique_ptr<mlir::Pass> createSdyRoundTripShardMapExportPass() {
return std::make_unique<SdyRoundTripShardMapExportPass>();
}

} // namespace sdy
} // namespace xla
36 changes: 36 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_EXPORT_H_
#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_EXPORT_H_

#include <memory>

#include "mlir/Pass/Pass.h"

namespace xla {
namespace sdy {

// Creates the pass that converts `ManualComputationOps` to a separate function
// and `CustomCallOp` for round tripping between HLO.
std::unique_ptr<mlir::Pass> createSdyRoundTripShardMapExportPass();

// Registers the xla-sdy-round-trip-shard-map-export pass.
void registerSdyRoundTripShardMapExportPass();

} // namespace sdy
} // namespace xla

#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_EXPORT_H_
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt %s -xla-sdy-shard-map-export 2>&1 | FileCheck %s
// RUN: sdy_opt %s -xla-mhlo-round-trip-shard-map-export 2>&1 | FileCheck %s

sdy.mesh @mesh_0 = <["a"=4, "b"=2]>
sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]>
Expand Down
Loading

0 comments on commit 1135d3c

Please sign in to comment.