From 02bccbeeffcb5231437a8fd3b07935938b034b7b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
Date: Sun, 2 Feb 2025 13:40:08 +0100
Subject: [PATCH] Prototype conversion

---
 BUILD                                       |   1 +
 src/enzyme_ad/jax/Passes/MPIToStableHLO.cpp | 322 ++++++++++++++++++++
 src/enzyme_ad/jax/Passes/Passes.td          |  15 +
 3 files changed, 338 insertions(+)
 create mode 100644 src/enzyme_ad/jax/Passes/MPIToStableHLO.cpp

diff --git a/BUILD b/BUILD
index d8aae1d56..7cf2d0907 100644
--- a/BUILD
+++ b/BUILD
@@ -55,6 +55,7 @@ cc_binary(
         "@llvm-project//mlir:MathDialect",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MlirOptLib",
+        "@llvm-project//mlir:MPIDialect",
         "@llvm-project//mlir:NVVMDialect",
         "@llvm-project//mlir:NVGPUDialect",
         "@llvm-project//mlir:OpenMPDialect",
diff --git a/src/enzyme_ad/jax/Passes/MPIToStableHLO.cpp b/src/enzyme_ad/jax/Passes/MPIToStableHLO.cpp
new file mode 100644
index 000000000..c873c691d
--- /dev/null
+++ b/src/enzyme_ad/jax/Passes/MPIToStableHLO.cpp
@@ -0,0 +1,322 @@
+//===- MPIToStableHLO.cpp - Convert MPI ops to StableHLO custom_call ops --===//
+//
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert MPI ops to StableHLO custom_call ops.
+//
+//===----------------------------------------------------------------------===//
+
+// NOTE we should be targeting libmpitrampoline ABI, since XLA already adds it
+// as a dependency and it fixes some issues with MPI ABI compatibility. In
+// particular, MPI types defined in the standard are not ABI-stable, so we must
+// use `uintptr_t` instead of `MPI_Comm`, `MPI_Request`, etc...
+// TODO or can we use libmpitrampoline's ABI directly? i.e. `MPIABI_Comm`, ...
+
+#include "Passes.h"
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+namespace mlir {
+namespace enzyme {
+#define GEN_PASS_DEF_LOWERMPITOSTABLEHLOPASS
+#include "src/enzyme_ad/jax/Passes/Passes.h.inc"
+} // namespace enzyme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::mpi;
+using namespace mlir::stablehlo;
+
+namespace {
+/// Lowers `mpi::InitOp` to a `stablehlo::CustomCallOp` targeting "mpi_init".
+struct InitOpLowering : public OpRewritePattern<mpi::InitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::InitOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_init"),
+        rewriter.getBoolAttr(true), // has_side_effect: keep even if result unused
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+/// Lowers `mpi::FinalizeOp` to a `stablehlo::CustomCallOp` targeting "mpi_finalize".
+struct FinalizeOpLowering : public OpRewritePattern<mpi::FinalizeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::FinalizeOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_finalize"),
+        rewriter.getBoolAttr(true), // has_side_effect: keep even if result unused
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+// struct CommWorldOpLowering : public OpRewritePattern<mpi::CommWorldOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::CommWorldOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_comm_world"),
+//         rewriter.getBoolAttr(false),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+/// Lowers `mpi::CommRankOp` to a `stablehlo::CustomCallOp` targeting "mpi_comm_rank".
+struct CommRankOpLowering : public OpRewritePattern<mpi::CommRankOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::CommRankOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_comm_rank"),
+        rewriter.getBoolAttr(false), // has_side_effect
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+// struct CommSizeOpLowering : public OpRewritePattern<mpi::CommSizeOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::CommSizeOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_comm_size"),
+//         rewriter.getBoolAttr(false),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+// struct CommSplitOpLowering : public OpRewritePattern<mpi::CommSplitOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::CommSplitOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_comm_split"),
+//         rewriter.getBoolAttr(true),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+/// Lowers `mpi::SendOp` to a `stablehlo::CustomCallOp` targeting "mpi_send".
+struct SendOpLowering : public OpRewritePattern<mpi::SendOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::SendOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_send"),
+        rewriter.getBoolAttr(true), // has_side_effect: keep even if result unused
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+/// Lowers `mpi::RecvOp` to a `stablehlo::CustomCallOp` targeting "mpi_recv".
+struct RecvOpLowering : public OpRewritePattern<mpi::RecvOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::RecvOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_recv"),
+        rewriter.getBoolAttr(true), // has_side_effect: keep even if result unused
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+// struct ISendOpLowering : public OpRewritePattern<mpi::ISendOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::ISendOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_isend"),
+//         rewriter.getBoolAttr(true),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+// struct IRecvOpLowering : public OpRewritePattern<mpi::IRecvOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::IRecvOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_irecv"),
+//         rewriter.getBoolAttr(true),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+// struct BarrierOpLowering : public OpRewritePattern<mpi::BarrierOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::BarrierOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_barrier"),
+//         rewriter.getBoolAttr(true),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+// struct WaitOpLowering : public OpRewritePattern<mpi::WaitOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::WaitOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_wait"),
+//         rewriter.getBoolAttr(true),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+// struct AllReduceOpLowering : public OpRewritePattern<mpi::AllReduceOp> {
+//   using OpRewritePattern::OpRewritePattern;
+
+//   LogicalResult matchAndRewrite(mpi::AllReduceOp op, PatternRewriter &rewriter) const override {
+//     rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+//         op, op.getResultTypes(), op.getOperands(),
+//         rewriter.getStringAttr("mpi_allreduce"),
+//         rewriter.getBoolAttr(true),
+//         rewriter.getDictionaryAttr({}),
+//         CustomCallApiVersionAttr::get(
+//             rewriter.getContext(),
+//             mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI),
+//         nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+//     return success();
+//   }
+// };
+
+/// Lowers `mpi::RetvalCheckOp` to a `stablehlo::CustomCallOp` targeting "mpi_retval_check".
+struct RetvalCheckOpLowering : public OpRewritePattern<mpi::RetvalCheckOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::RetvalCheckOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_retval_check"),
+        rewriter.getBoolAttr(false), // has_side_effect
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+/// Lowers `mpi::ErrorClassOp` to a `stablehlo::CustomCallOp` targeting "mpi_error_class".
+struct ErrorClassOpLowering : public OpRewritePattern<mpi::ErrorClassOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mpi::ErrorClassOp op, PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
+        op, op.getResultTypes(), op.getOperands(),
+        rewriter.getStringAttr("mpi_error_class"),
+        rewriter.getBoolAttr(false), // has_side_effect
+        rewriter.getDictionaryAttr({}),
+        CustomCallApiVersionAttr::get(rewriter.getContext(),
+            CustomCallApiVersion::API_VERSION_TYPED_FFI),
+        nullptr, ArrayAttr(), ArrayAttr(), ArrayAttr());
+    return success();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrites MPI dialect ops into `stablehlo::CustomCallOp`s.
+struct LowerMPIToStableHLOPass
+    : public enzyme::impl::LowerMPIToStableHLOPassBase<LowerMPIToStableHLOPass> {
+  using LowerMPIToStableHLOPassBase::LowerMPIToStableHLOPassBase;
+
+  void runOnOperation() override {
+    ConversionTarget target(getContext());
+
+    // XLA can't handle MPI ops, so we must convert all MPI ops to
+    // `stablehlo.custom_call` ops. Marking the dialect illegal makes the pass
+    // fail if an MPI op without a lowering pattern is left behind.
+    target.addIllegalDialect<mpi::MPIDialect>();
+
+    RewritePatternSet patterns(&getContext());
+    patterns.add<InitOpLowering, FinalizeOpLowering, CommRankOpLowering,
+                 SendOpLowering, RecvOpLowering, RetvalCheckOpLowering,
+                 ErrorClassOpLowering>(&getContext());
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td
index 4cf563335..d3ce326f3 100644
--- a/src/enzyme_ad/jax/Passes/Passes.td
+++ b/src/enzyme_ad/jax/Passes/Passes.td
@@ -421,4 +421,19 @@ def EnzymeLiftControlFlowToSCFPass : Pass<"enzyme-lift-cf-to-scf"> {
     "func::FuncDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MPIToStableHLO
+//===----------------------------------------------------------------------===//
+
+def LowerMPIToStableHLOPass : Pass<"convert-mpi-to-stablehlo"> {
+  let summary = "Lower MPI ops to the StableHLO custom calls";
+  let dependentDialects = [
+    "mpi::MPIDialect",
+    "stablehlo::StablehloDialect"
+  ];
+
+  // TODO do we need to add options for getting libmpi path?
+  let options = [];
+}
+
 #endif