Skip to content

Commit

Permalink
Prototype conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 2, 2025
1 parent 5e2e43a commit 02bccbe
Show file tree
Hide file tree
Showing 3 changed files with 338 additions and 0 deletions.
1 change: 1 addition & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ cc_binary(
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:MPIDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVGPUDialect",
"@llvm-project//mlir:OpenMPDialect",
Expand Down
322 changes: 322 additions & 0 deletions src/enzyme_ad/jax/Passes/MPIToStableHLO.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
//===- MPIToStableHLO.cpp - Convert MPI ops to StableHLO custom_call ops --===//
//
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MPI ops to StableHLO custom_call ops.
//
//===----------------------------------------------------------------------===//

// NOTE we should be targetting libmpitrampoline ABI, since XLA already adds it
// as a dependency and it fix some issues with MPI ABI compatibility. In
// particular, MPI types defined in the standard are not ABI-stable, so we must
// use `uintptr_t` instead of `MPI_Comm`, `MPI_Request`, etc...
// TODO or can we use libmpitrampoline's ABI directly? i.e. `MPIABI_Comm`, ...

#include "Passes.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include "mlir/Dialect/MPI/IR/MPI.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_LOWERMPITOSTABLEHLOPASS
#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

using namespace mlir;
using namespace mlir::mpi;
using namespace stablehlo;

namespace {
struct InitOpLowering : public OpRewritePattern<mpi::InitOp> {
using OpRewritePattern<mpi::InitOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::InitOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_init"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

struct FinalizeOpLowering : public OpRewritePattern<mpi::FinalizeOp> {
using OpRewritePattern<mpi::FinalizeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::FinalizeOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_finalize"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

// struct CommWorldOpLowering : public OpRewritePattern<mpi::CommWorldOp> {
// using OpRewritePattern<mpi::CommWorldOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::CommWorldOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_comm_world"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

struct CommRankOpLowering : public OpRewritePattern<mpi::CommRankOp> {
using OpRewritePattern<mpi::CommRankOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::CommRankOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_comm_rank"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

// struct CommSizeOpLowering : public OpRewritePattern<mpi::CommSizeOp> {
// using OpRewritePattern<mpi::CommSizeOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::CommSizeOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_comm_size"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

// struct CommSplitOpLowering : public OpRewritePattern<mpi::CommSplitOp> {
// using OpRewritePattern<mpi::CommSplitOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::CommSplitOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_comm_split"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

struct SendOpLowering : public OpRewritePattern<mpi::SendOp> {
using OpRewritePattern<mpi::SendOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::SendOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_send"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

struct RecvOpLowering : public OpRewritePattern<mpi::RecvOp> {
using OpRewritePattern<mpi::RecvOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::RecvOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_recv"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

// struct ISendOpLowering : public OpRewritePattern<mpi::ISendOp> {
// using OpRewritePattern<mpi::ISendOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::ISendOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_isend"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

// struct IRecvOpLowering : public OpRewritePattern<mpi::IRecvOp> {
// using OpRewritePattern<mpi::IRecvOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::IRecvOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_irecv"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

// struct BarrierOpLowering : public OpRewritePattern<mpi::BarrierOp> {
// using OpRewritePattern<mpi::BarrierOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::BarrierOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_barrier"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

// struct WaitOpLowering : public OpRewritePattern<mpi::WaitOp> {
// using OpRewritePattern<mpi::WaitOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::WaitOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_wait"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

// struct AllReduceOpLowering : public OpRewritePattern<mpi::AllReduceOp> {
// using OpRewritePattern<mpi::AllReduceOp>::OpRewritePattern;

// LogicalResult matchAndRewrite(mpi::AllReduceOp op, PatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
// op, op.getResultTypes(), op.getOperands(),
// rewriter.getStringAttr("mpi_allreduce"),
// rewriter.getBoolAttr(false),
// rewriter.getDictionaryAttr({}),
// CustomCallApiVersionAttr::get(
// rewriter.getContext(),
// mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
// nullptr, ValueRange(), ValueRange(), ValueRange());
// return success();
// }
// };

struct RetvalCheckOpLowering : public OpRewritePattern<mpi::RetvalCheckOp> {
using OpRewritePattern<mpi::RetvalCheckOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::RetvalCheckOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_retval_check"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

struct ErrorClassOpLowering : public OpRewritePattern<mpi::ErrorClassOp> {
using OpRewritePattern<mpi::ErrorClassOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mpi::ErrorClassOp op, PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
op, op.getResultTypes(), op.getOperands(),
rewriter.getStringAttr("mpi_error_class"),
rewriter.getBoolAttr(false),
rewriter.getDictionaryAttr({}),
CustomCallApiVersionAttr::get(
rewriter.getContext(),
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
nullptr, ValueRange(), ValueRange(), ValueRange());
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
struct LowerMPIToStableHLOPass : public LowerMPIToStableHLOPassBase<LowerMPIToStableHLOPass> {
using LowerMPIToStableHLOPassBase::LowerMPIToStableHLOPassBase;
void runOnOperation() override {
ConversionTarget target(getContext());

// XLA can't handle MPI ops, so we must convert all MPI ops to `stablehlo.custom_call` ops
target.addIllegalDialect<MPI::MPIDialect>();

RewritePatternSet patterns(&getContext());
patterns.add<
InitOpLowering,
FinalizeOpLowering,
CommRankOpLowering,
SendOpLowering,
RecvOpLowering,
RetvalCheckOpLowering,
ErrorClassOpLowering
>(&getContext());

if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
}
}
}
} // namespace
15 changes: 15 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -421,4 +421,19 @@ def EnzymeLiftControlFlowToSCFPass : Pass<"enzyme-lift-cf-to-scf"> {
"func::FuncDialect"];
}

//===----------------------------------------------------------------------===//
// MPIToStableHLO
//===----------------------------------------------------------------------===//

def LowerMPIToStableHLOPass : Pass<"convert-mpi-to-stablehlo"> {
let summary = "Lower MPI ops to the StableHLO custom calls";
let dependentDialects = [
"mpi::MPIDialect",
"stablehlo::StablehloDialect"
];

// TODO do we need to add options for getting libmpi path?
let options = [];
}

#endif

0 comments on commit 02bccbe

Please sign in to comment.