Skip to content

Commit

Permalink
#sdy Add sdy round trip shard map import.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675206996
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Sep 16, 2024
1 parent a03a49b commit 590cd6f
Show file tree
Hide file tree
Showing 9 changed files with 405 additions and 20 deletions.
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,10 @@ cc_library(
deps = [
":constants",
"//xla/mlir_hlo",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:IR",
Expand Down Expand Up @@ -152,6 +154,7 @@ xla_cc_binary(
"//xla/service/spmd/shardy/sdy_round_trip:import_shardings",
"//xla/service/spmd/shardy/sdy_round_trip:pipelines",
"//xla/service/spmd/shardy/sdy_round_trip:shard_map_export",
"//xla/service/spmd/shardy/sdy_round_trip:shard_map_import",
"//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo",
"//xla/service/spmd/shardy/sdy_round_trip/test_utils:testing_pipeline",
"@llvm-project//mlir:AllPassesAndDialects",
Expand Down
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/sdy_opt_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h"
#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h"
#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h"
#include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h"
#include "xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.h"

Expand Down Expand Up @@ -70,6 +71,7 @@ int main(int argc, char** argv) {
xla::sdy::registerSdyRoundTripExportOpsPass();
xla::sdy::registerSdyRoundTripExportPipeline();
xla::sdy::registerSdyRoundTripShardMapExportPass();
xla::sdy::registerSdyRoundTripShardMapImportPass();
xla::sdy::registerSdyRoundTripImportPipeline();
xla::sdy::registerSdyRoundTripTestingPipeline();

Expand Down
21 changes: 21 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,27 @@ cc_library(
],
)

cc_library(
name = "shard_map_import",
srcs = ["shard_map_import.cc"],
hdrs = ["shard_map_import.h"],
deps = [
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "pipelines",
srcs = ["pipelines.cc"],
Expand Down
20 changes: 0 additions & 20 deletions xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,6 @@ using ::mlir::sdy::MeshAttr;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::sdy::TensorShardingPerValueAttr;

// Parses `stringAttr` to an attribute of type `AttrTy`.
//
// NOTE: assumes `stringAttr` is of type `StringAttr`.
template <typename AttrTy>
AttrTy parseStringAttr(Attribute stringAttr) {
std::string value;
std::string error;
CHECK(absl::CUnescape(mlir::cast<StringAttr>(stringAttr).getValue(), &value,
&error))
<< error;
return mlir::cast<AttrTy>(
mlir::parseAttribute(value, stringAttr.getContext()));
}

// Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`.
template <typename AttrTy>
AttrTy parseStringAttr(DictionaryAttr dictAttr, llvm::StringRef attrName) {
return parseStringAttr<AttrTy>(dictAttr.get(attrName));
}

// Builds the shardings coming from Shardy previously. This means
// the module was exported from Shardy and we are now round-tripping back.
// This should happen after the meshes were created from the `ModuleOp` attrs
Expand Down
164 changes: 164 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h"

#include <memory>
#include <utility>

#include "absl/log/check.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/DialectConversion.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"

namespace xla {
namespace sdy {

namespace {

using ::mlir::MLIRContext;
using ::mlir::ModuleOp;
using ::mlir::OpConversionPattern;
using ::mlir::StringRef;
using ::mlir::SymbolTable;
using ::mlir::func::FuncOp;
using ::mlir::stablehlo::CustomCallOp;

namespace sdy = ::mlir::sdy;

// Converts `CustomCallOp`s called `@xla.sdy.ManualComputation` with in/out
// shardings and manual axes as frontend attrs to `ManualComputationOp`s.
class ManualComputationPattern : public OpConversionPattern<CustomCallOp> {
public:
explicit ManualComputationPattern(MLIRContext* context,
const SymbolTable& symbolTable)
: OpConversionPattern<CustomCallOp>(context), symbolTable(symbolTable) {}

mlir::LogicalResult matchAndRewrite(
CustomCallOp customCallOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter& rewriter) const override {
if (customCallOp.getCallTargetName() !=
kManualComputationCustomCallTargetName) {
return mlir::failure();
}

CHECK_EQ(customCallOp.getCalledComputations().size(), 1);
auto shmapBodyFunc =
symbolTable.lookup<FuncOp>((*customCallOp.getCalledComputations()
.getAsRange<mlir::FlatSymbolRefAttr>()
.begin())
.getValue());
if (shmapBodyFunc.empty()) {
return customCallOp->emitOpError(
"expected a unique FuncOp per "
"@xla.sdy.ManualComputation custom call. Were "
"functions maybe somehow shared/de-duped between "
"two ManualComputations?");
}

mlir::DictionaryAttr frontendAttrs = getFrontendAttrs(customCallOp);
CHECK(frontendAttrs);
auto manualComputationOp =
rewriter.replaceOpWithNewOp<sdy::ManualComputationOp>(
customCallOp, customCallOp->getResultTypes(),
customCallOp->getOperands(),
parseStringAttr<sdy::TensorShardingPerValueAttr>(frontendAttrs,
kInShardings),
parseStringAttr<sdy::TensorShardingPerValueAttr>(frontendAttrs,
kOutShardings),
parseStringAttr<sdy::ManualAxesAttr>(frontendAttrs, kManualAxes));
sdy::inlineRegionAndConvertTerminatorOp<sdy::ReturnOp>(
shmapBodyFunc.getBody(), manualComputationOp.getRegion(), rewriter);
rewriter.eraseOp(shmapBodyFunc);
return mlir::success();
}

private:
const SymbolTable& symbolTable;
};

class SdyRoundTripShardMapImportPass
: public mlir::PassWrapper<SdyRoundTripShardMapImportPass,
mlir::OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SdyRoundTripShardMapImportPass)

private:
void runOnOperation() final {
ModuleOp module = getOperation();
mlir::SymbolTableCollection symbolTableCollection;
SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module);
MLIRContext& context = getContext();
mlir::ConversionTarget target(context);
target.addDynamicallyLegalOp<CustomCallOp>([](CustomCallOp op) {
return op.getCallTargetName() != kManualComputationCustomCallTargetName;
});
target.addLegalOp<sdy::ManualComputationOp, sdy::ReturnOp>();
mlir::RewritePatternSet patterns(&context);
patterns.add<ManualComputationPattern>(&context, symbolTable);
if (mlir::failed(mlir::applyPartialConversion(module, target,
std::move(patterns)))) {
signalPassFailure();
}
}

StringRef getArgument() const override {
return "xla-sdy-round-trip-shard-map-import";
}

StringRef getDescription() const override {
return "converts CustomCalls called @xla.sdy.manual_computation_body "
"with in/out shardings and manual axes as frontend attrs to a "
"`ManualComputationOp`";
}
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<sdy::SdyDialect>();
}
};

} // namespace

void registerSdyRoundTripShardMapImportPass() {
mlir::registerPass(createSdyRoundTripShardMapImportPass);
}

std::unique_ptr<mlir::Pass> createSdyRoundTripShardMapImportPass() {
return std::make_unique<SdyRoundTripShardMapImportPass>();
}

} // namespace sdy
} // namespace xla
37 changes: 37 additions & 0 deletions xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_IMPORT_H_
#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_IMPORT_H_

#include <memory>

#include "mlir/Pass/Pass.h"

namespace xla {
namespace sdy {

// Creates the pass that converts a `CustomCallOp` called
// `kManualComputationBodyFuncName` with in/out shardings and manual
// axes as frontend attrs to a `ManualComputationOp`.
std::unique_ptr<mlir::Pass> createSdyRoundTripShardMapImportPass();

// Registers the xla-sdy-round-trip-shard-map-import pass.
void registerSdyRoundTripShardMapImportPass();

} // namespace sdy
} // namespace xla

#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_SHARD_MAP_IMPORT_H_
Loading

0 comments on commit 590cd6f

Please sign in to comment.