#sdy Add support for exporting nested ManualComputations in MHLO export.
Also adds some extra debug dumps during import/export.

PiperOrigin-RevId: 674248454
bartchr808 authored and Google-ML-Automation committed Sep 13, 2024
1 parent bad4cc8 commit 8a450b6
Showing 5 changed files with 151 additions and 57 deletions.
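The crux of the change: the axes that are manual inside a nested ManualComputationOp's region are the union of its own manual axes and the manual axes of every ManualComputationOp it is nested in, and the sharding becomes fully manual once that union covers every mesh axis. The following standalone C++ sketch (not part of the commit; names are hypothetical) illustrates that accumulation:

#include <cstddef>
#include <string>
#include <vector>

// Union of the parent's accumulated manual axes and the op's own manual axes,
// mirroring how regionManualAxes is built in shard_map_export.cc below.
std::vector<std::string> accumulateManualAxes(
    const std::vector<std::string>& parentAxes,
    const std::vector<std::string>& ownAxes) {
  std::vector<std::string> regionAxes(ownAxes);
  regionAxes.insert(regionAxes.end(), parentAxes.begin(), parentAxes.end());
  return regionAxes;
}

// The region is fully manual iff every axis of the mesh is manual.
bool isFullyManual(std::size_t numMeshAxes,
                   const std::vector<std::string>& regionAxes) {
  return regionAxes.size() == numMeshAxes;
}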
18 changes: 11 additions & 7 deletions xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc
@@ -233,7 +233,7 @@ class ExportMhloShardingsPass
HloSharding convertToHloSharding(
TensorShardingAttr sdySharding,
std::function<MeshAttr(TensorShardingAttr)> getMeshAttr,
ArrayRef<AxisRefAttr> manualAxes) {
ArrayRef<StringAttr> manualAxes) {
MeshAttr mesh = getMeshAttr(sdySharding);

// If there are no axes, convert to:
@@ -250,6 +250,10 @@ HloSharding convertToHloSharding(
SmallVector<OpSharding::Type> types;
int64_t shardedPos = 0;

if (mesh.getAxes().size() == manualAxes.size()) {
return HloSharding::Manual();
}

// Iterate the dim shardings.
for (auto [index, dimSharding] :
llvm::enumerate(sdySharding.getDimShardings())) {
@@ -263,9 +267,10 @@ HloSharding convertToHloSharding(
if (!manualAxes.empty()) {
types.push_back(OpSharding::MANUAL);
int64_t& manualDim = tileAssignmentDims.emplace_back(1);
for (AxisRefAttr axisRef : manualAxes) {
manualDim *= axisRef.getSize(mesh);
axisRefToShardedPos[axisRef] = shardedPos++;
mlir::MLIRContext* context = sdySharding.getContext();
for (StringRef manualAxis : manualAxes) {
manualDim *= mesh.getAxisSize(manualAxis);
axisRefToShardedPos[AxisRefAttr::get(context, manualAxis)] = shardedPos++;
}
}

@@ -317,12 +322,11 @@ StringAttr convertToHloShardingAttr(
Operation* op, ArrayRef<TensorShardingAttr> shardings,
std::function<MeshAttr(TensorShardingAttr)> getMeshAttr,
std::function<StringAttr(const HloSharding&)> getStringAttr,
ArrayRef<AxisRefAttr> manualAxes) {
ArrayRef<StringAttr> manualAxes) {
assert(shardings.size() == op->getNumResults());
if (op->getNumResults() == 1) {
TensorShardingAttr sdySharding = shardings.front();
return getStringAttr(
convertToHloSharding(sdySharding, getMeshAttr, manualAxes));
convertToHloSharding(shardings.front(), getMeshAttr, manualAxes));
}

SmallVector<HloSharding> newShardings;
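The new early return in convertToHloSharding short-circuits to HloSharding::Manual() when every mesh axis is manual; otherwise the manual axes contribute a single MANUAL subgroup dimension whose size is the product of their sizes. A small illustrative helper (not from the commit) for that subgroup size, with a worked example in the comment:

#include <cstdint>
#include <vector>

// Product of the manual axis sizes, i.e. the size of the MANUAL subgroup
// dimension appended to the tile assignment. For a mesh {x=2, y=4} with
// manual axes {x}, this is 2; with manual axes {x, y} the early return above
// kicks in and the sharding is simply HloSharding::Manual().
int64_t manualSubgroupSize(const std::vector<int64_t>& manualAxisSizes) {
  int64_t size = 1;
  for (int64_t axisSize : manualAxisSizes) size *= axisSize;
  return size;
}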
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h
@@ -35,7 +35,7 @@ HloSharding convertToHloSharding(
mlir::sdy::TensorShardingAttr sdySharding,
std::function<mlir::sdy::MeshAttr(mlir::sdy::TensorShardingAttr)>
getMeshAttr,
mlir::ArrayRef<mlir::sdy::AxisRefAttr> manualAxes = {});
mlir::ArrayRef<mlir::StringAttr> manualAxes = {});

// Convert the `shardings` into a `StringAttr` representing `xla::HloSharding`
// for the given `op`.
@@ -45,7 +45,7 @@ mlir::StringAttr convertToHloShardingAttr(
std::function<mlir::sdy::MeshAttr(mlir::sdy::TensorShardingAttr)>
getMeshAttr,
std::function<mlir::StringAttr(const HloSharding&)> getStringAttr,
mlir::ArrayRef<mlir::sdy::AxisRefAttr> manualAxes = {});
mlir::ArrayRef<mlir::StringAttr> manualAxes = {});

// Creates a pass that converts the shardings from `kShardingAttr` to
// `kXlaShardingAttr` and removes mesh symbols. Fully or partially manual
108 changes: 74 additions & 34 deletions xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>

#include "absl/log/check.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@@ -80,19 +81,25 @@ using ::mlir::mhlo::CopyOp;
using ::mlir::mhlo::CustomCallOp;

namespace sdy = ::mlir::sdy;
using sdy::AxisRefAttr;
using sdy::kShardingAttr;
using sdy::ManualComputationOp;
using sdy::MeshAttr;
using sdy::SdyDialect;
using sdy::TensorShardingAttr;
using sdy::TensorShardingPerValueAttr;

// Mapping from ManualComputationOp to all manual axes it's nested in.
using ManualComputationToParentManualAxes =
llvm::SmallDenseMap<ManualComputationOp, SmallVector<StringAttr>>;

class ManualComputationPattern
: public OpConversionPattern<ManualComputationOp> {
public:
explicit ManualComputationPattern(MLIRContext* context)
: OpConversionPattern<ManualComputationOp>(context) {
explicit ManualComputationPattern(
MLIRContext* context,
const ManualComputationToParentManualAxes& parentManualCompAxes)
: OpConversionPattern<ManualComputationOp>(context),
parentManualCompAxes(parentManualCompAxes) {
// We call this function so that MLIR applies the pattern to any
// ManualComputationOp that uses another ManualComputationOp.
setHasBoundedRewriteRecursion(true);
@@ -118,17 +125,22 @@ class ManualComputationPattern
StringRef meshName = inOutShardings.begin()->getMeshName();
MeshAttr mesh = mlir::sdy::getMeshAttr(op, meshName);
CHECK(mesh);

MLIRContext* context = rewriter.getContext();

// The axes that are manual inside `op`'s region.
SmallVector<StringAttr> regionManualAxes(op.getManualAxes().begin(),
op.getManualAxes().end());
mlir::ArrayRef<StringAttr> parentManualAxes;
if (parentManualCompAxes.contains(op)) {
parentManualAxes = parentManualCompAxes.at(op);
regionManualAxes.append(parentManualAxes.begin(), parentManualAxes.end());
}

// If `fullyManual` is true, all axes are manual. Otherwise, partial axes
// are manual and other axes are free (sharded or replicated) in the body of
// the manual computation.
bool fullyManual = mesh.getAxes().size() == op.getManualAxes().size();

MLIRContext* context = rewriter.getContext();
SmallVector<AxisRefAttr> manualAxes;
llvm::transform(op.getManualAxes(), std::back_inserter(manualAxes),
[&](StringAttr manualAxis) {
return AxisRefAttr::get(context, manualAxis);
});
bool fullyManual = mesh.getAxes().size() == regionManualAxes.size();

std::function<StringAttr(const HloSharding&)> getStringAttr =
[&](const HloSharding& hloSharding) {
Expand All @@ -146,7 +158,7 @@ class ManualComputationPattern
TensorShardingAttr fullyOpen =
TensorShardingAttr::getFullyOpen(context, rank, meshName);
HloSharding hloSharding =
convertToHloSharding(fullyOpen, getMeshAttr, manualAxes);
convertToHloSharding(fullyOpen, getMeshAttr, regionManualAxes);
return getStringAttr(hloSharding);
};

@@ -164,27 +176,37 @@ class ManualComputationPattern
// We export the shardings in the body.
if (fullyManual) {
// All operations in the body have fully manual sharding.
op.getBody().front().walk([&](Operation* opInBody) {
opInBody->setAttr(kXlaShardingAttr, fullyManualSharding);
// Remove the possible fully replicated sdy.sharding attribute.
opInBody->removeAttr(kShardingAttr);
});
op.getBody().front().walk<mlir::WalkOrder::PreOrder>(
[&](Operation* opInBody) {
if (mlir::isa<ManualComputationOp>(opInBody)) {
return mlir::WalkResult::skip();
}
opInBody->setAttr(kXlaShardingAttr, fullyManualSharding);
// Remove the possible fully replicated sdy.sharding attribute.
opInBody->removeAttr(kShardingAttr);
return mlir::WalkResult::advance();
});
} else {
// All operations in the body must be sharded or replicated along free
// axes. If an operation does not have sharding annotation, it is fully
// replicated along free axes.
op.getBody().front().walk([&](Operation* opInBody) {
op.getBody().front().walk<mlir::WalkOrder::PreOrder>([&](Operation*
opInBody) {
if (mlir::isa<ManualComputationOp>(opInBody)) {
return mlir::WalkResult::skip();
}
TensorShardingPerValueAttr shardingPerValue =
opInBody->getAttrOfType<TensorShardingPerValueAttr>(kShardingAttr);
if (!shardingPerValue) {
shardingPerValue = TensorShardingPerValueAttr::getFullyOpen(
context, opInBody->getResultTypes(), meshName);
}
opInBody->setAttr(
kXlaShardingAttr,
convertToHloShardingAttr(opInBody, shardingPerValue.getShardings(),
getMeshAttr, getStringAttr, manualAxes));
opInBody->setAttr(kXlaShardingAttr,
convertToHloShardingAttr(
opInBody, shardingPerValue.getShardings(),
getMeshAttr, getStringAttr, regionManualAxes));
opInBody->removeAttr(kShardingAttr);
return mlir::WalkResult::advance();
});
}

@@ -197,9 +219,9 @@ class ManualComputationPattern
llvm::zip_equal(adaptor.getOperands(), op.getBody().getArgumentTypes(),
adaptor.getInShardings().getShardings())) {
auto copy = rewriter.create<CopyOp>(loc, globalOperand);
copy->setAttr(kShardingAttr,
TensorShardingPerValueAttr::get(context, inSharding));

copy->setAttr(kXlaShardingAttr,
getStringAttr(convertToHloSharding(inSharding, getMeshAttr,
parentManualAxes)));
if (!fullyManual) {
fullToShardAttributes.back() = rewriter.getNamedAttr(
kXlaShardingAttr, partialManualSharding(localArgumentType));
Expand All @@ -209,9 +231,8 @@ class ManualComputationPattern
fullToShardResults.push_back(fullToShard.getResult(0));
}

Operation* terminator = getBodyTerminator(op);
Operation* terminator = getBodyTerminator(adaptor);
rewriter.inlineBlockBefore(&op.getBody().front(), op, fullToShardResults);

// Add custom_call @SPMDShardToFullShape and copy for each operand of
// terminator.
for (auto [terminatorOperand, opResult, outSharding] :
@@ -222,18 +243,26 @@ class ManualComputationPattern
fullyManual
? fullyManualSharding
: partialManualSharding(copy.getResult().getType()));

shardToFullAttributes.back() = rewriter.getNamedAttr(
kShardingAttr, TensorShardingPerValueAttr::get(context, outSharding));
kXlaShardingAttr, getStringAttr(convertToHloSharding(
outSharding, getMeshAttr, parentManualAxes)));
auto shardToFull = rewriter.create<CustomCallOp>(
loc, opResult.getType(), copy.getResult(), shardToFullAttributes);
rewriter.replaceAllUsesWith(opResult, shardToFull.getResult(0));
}

rewriter.eraseOp(terminator);
rewriter.eraseOp(op);
// NOTE: we can't just `rewriter.eraseOp` the terminator, because when the
// conversion pattern runs again, the terminator's operands would still list
// the terminator as a user. For whatever reason, maybe a bug in MLIR, we
// need to explicitly clear the terminator's operands first so that their
// use lists no longer include the terminator.
terminator->setOperands({});
rewriter.eraseOp(terminator);
return mlir::success();
}

private:
const ManualComputationToParentManualAxes& parentManualCompAxes;
};

class ShardMapExportPass
@@ -243,6 +272,17 @@ class ShardMapExportPass

private:
void runOnOperation() final {
ManualComputationToParentManualAxes parentManualCompAxes;
ModuleOp module = getOperation();
module->walk<mlir::WalkOrder::PreOrder>([&](ManualComputationOp op) {
if (auto parentOp = op->getParentOfType<ManualComputationOp>()) {
SmallVector<StringAttr>& parentAxes = parentManualCompAxes[op];
parentAxes = parentManualCompAxes[parentOp];
parentAxes.insert(parentAxes.end(), parentOp.getManualAxes().begin(),
parentOp.getManualAxes().end());
}
});

MLIRContext& context = getContext();
mlir::ConversionTarget target(context);
target.addIllegalOp<ManualComputationOp>();
@@ -253,8 +293,8 @@ class ShardMapExportPass
// be nested within an MHLO op, e.g., a while loop.
target.addLegalDialect<mlir::func::FuncDialect, mlir::mhlo::MhloDialect>();
mlir::RewritePatternSet patterns(&context);
patterns.add<ManualComputationPattern>(&context);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
patterns.add<ManualComputationPattern>(&context, parentManualCompAxes);
if (mlir::failed(mlir::applyPartialConversion(module, target,
std::move(patterns)))) {
signalPassFailure();
}
@@ -268,7 +308,7 @@ class ShardMapExportPass
}

void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<SdyDialect>();
registry.insert<SdyDialect, mlir::mhlo::MhloDialect>();
}
};

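One design point in the pattern above is worth spelling out: the body walks are pre-order and return WalkResult::skip() when they reach a nested ManualComputationOp, so the parent's rewrite does not stamp shardings onto operations that the nested op's own rewrite will handle. A minimal standalone analogue of that skip-on-match traversal (plain C++, no MLIR types; everything here is illustrative):

#include <functional>
#include <vector>

struct Node {
  bool isNestedManualComputation = false;
  std::vector<Node> children;
};

// Pre-order walk that skips the entire subtree rooted at a nested
// "manual computation" node, analogous to returning WalkResult::skip().
void walkSkippingNested(Node& node, const std::function<void(Node&)>& visit) {
  if (node.isNestedManualComputation) {
    return;  // Leave this subtree for its own rewrite to handle.
  }
  visit(node);
  for (Node& child : node.children) {
    walkSkippingNested(child, visit);
  }
}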
8 changes: 4 additions & 4 deletions xla/service/spmd/shardy/shardy_xla_pass.cc
@@ -322,6 +322,8 @@ absl::StatusOr<bool> ShardyXLA::Run(

mlir::PassManager pm(mlirContext.get());
pm.enableVerifier(enableVerifier);
pm.addPass(mlir::sdy::createSaveModuleOpPass(shardyDir,
"sdy_module_before_xla_import"));
bool useTupleArgs = false;
mlir::DictionaryAttr moduleFrontendAttrs = getFrontendAttrs(*mlirModule);
if (moduleFrontendAttrs && moduleFrontendAttrs.get(kUseTupleArgs)) {
@@ -366,8 +368,6 @@ absl::StatusOr<bool> ShardyXLA::Run(
originalParamIndexToFlattenedNum,
useTupleArgs);

pm.addPass(
mlir::sdy::createSaveModuleOpPass(shardyDir, "sdy_module_after_import"));
if (runSdyShardingPropagation) {
// Shardy is currently operating on stablehlo, since this is what JAX
// emits. Long term shardy will be fully dialect agnostic, and both mhlo
@@ -381,8 +381,8 @@
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
}
addMhloExportPipeline(pm);
pm.addPass(
mlir::sdy::createSaveModuleOpPass(shardyDir, "sdy_module_after_export"));
pm.addPass(mlir::sdy::createSaveModuleOpPass(shardyDir,
"sdy_module_after_xla_export"));
tsl::StatusScopedDiagnosticHandler diagnosticHandler(mlirContext.get());
TF_RETURN_IF_ERROR(diagnosticHandler.consumeStatus(pm.run(*mlirModule)));

(diff for the fifth changed file not loaded)