-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Initial implementation of lowering CIR to MLIR #127835
base: main
Are you sure you want to change the base?
Changes from 2 commits
a03e9a9
8cad91d
738485c
a16303f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ static BackendAction | |
getBackendActionFromOutputType(CIRGenAction::OutputType Action) { | ||
switch (Action) { | ||
case CIRGenAction::OutputType::EmitCIR: | ||
case CIRGenAction::OutputType::EmitMLIR: | ||
assert(false && | ||
"Unsupported output type for getBackendActionFromOutputType!"); | ||
break; // Unreachable, but fall through to report that | ||
|
@@ -82,6 +83,7 @@ class CIRGenConsumer : public clang::ASTConsumer { | |
void HandleTranslationUnit(ASTContext &C) override { | ||
Gen->HandleTranslationUnit(C); | ||
mlir::ModuleOp MlirModule = Gen->getModule(); | ||
mlir::MLIRContext &MlirCtx = Gen->getMLIRContext(); | ||
switch (Action) { | ||
case CIRGenAction::OutputType::EmitCIR: | ||
if (OutputStream && MlirModule) { | ||
|
@@ -90,6 +92,15 @@ class CIRGenConsumer : public clang::ASTConsumer { | |
MlirModule->print(*OutputStream, Flags); | ||
} | ||
break; | ||
case CIRGenAction::OutputType::EmitMLIR: { | ||
auto LoweredMlirModule = lowerFromCIRToMLIR(MlirModule, MlirCtx); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please spell out the type. |
||
assert(OutputStream && "No output stream when lowering to MLIR!"); | ||
// FIXME: we cannot roundtrip prettyForm=true right now. | ||
mlir::OpPrintingFlags Flags; | ||
Flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/false); | ||
LoweredMlirModule->print(*OutputStream, Flags); | ||
break; | ||
} | ||
case CIRGenAction::OutputType::EmitLLVM: | ||
case CIRGenAction::OutputType::EmitBC: | ||
case CIRGenAction::OutputType::EmitObj: | ||
|
@@ -124,6 +135,8 @@ getOutputStream(CompilerInstance &CI, StringRef InFile, | |
return CI.createDefaultOutputFile(false, InFile, "s"); | ||
case CIRGenAction::OutputType::EmitCIR: | ||
return CI.createDefaultOutputFile(false, InFile, "cir"); | ||
case CIRGenAction::OutputType::EmitMLIR: | ||
return CI.createDefaultOutputFile(false, InFile, "mlir"); | ||
case CIRGenAction::OutputType::EmitLLVM: | ||
return CI.createDefaultOutputFile(false, InFile, "ll"); | ||
case CIRGenAction::OutputType::EmitBC: | ||
|
@@ -155,6 +168,10 @@ void EmitCIRAction::anchor() {} | |
EmitCIRAction::EmitCIRAction(mlir::MLIRContext *MLIRCtx) | ||
: CIRGenAction(OutputType::EmitCIR, MLIRCtx) {} | ||
|
||
void EmitMLIRAction::anchor() {} | ||
EmitMLIRAction::EmitMLIRAction(mlir::MLIRContext *MLIRCtx) | ||
: CIRGenAction(OutputType::EmitMLIR, MLIRCtx) {} | ||
|
||
void EmitLLVMAction::anchor() {} | ||
EmitLLVMAction::EmitLLVMAction(mlir::MLIRContext *MLIRCtx) | ||
: CIRGenAction(OutputType::EmitLLVM, MLIRCtx) {} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
add_subdirectory(DirectToLLVM) | ||
add_subdirectory(ThroughMLIR) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
set(LLVM_LINK_COMPONENTS | ||
Core | ||
Support | ||
) | ||
|
||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) | ||
|
||
add_clang_library(clangCIRLoweringThroughMLIR | ||
LowerCIRToMLIR.cpp | ||
|
||
DEPENDS | ||
LINK_LIBS | ||
MLIRIR | ||
${dialect_libs} | ||
MLIRCIR | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
//====- LowerCIRToMLIR.cpp - Lowering from CIR to MLIR --------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements lowering of CIR operations to MLIR. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/IR/BuiltinDialect.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "clang/CIR/Dialect/IR/CIRDialect.h" | ||
#include "clang/CIR/Dialect/IR/CIRTypes.h" | ||
#include "clang/CIR/LowerToLLVM.h" | ||
#include "clang/CIR/MissingFeatures.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/Support/TimeProfiler.h" | ||
|
||
using namespace cir; | ||
using namespace llvm; | ||
|
||
namespace cir { | ||
|
||
struct ConvertCIRToMLIRPass | ||
: public mlir::PassWrapper<ConvertCIRToMLIRPass, | ||
mlir::OperationPass<mlir::ModuleOp>> { | ||
void getDependentDialects(mlir::DialectRegistry ®istry) const override { | ||
registry.insert<mlir::BuiltinDialect, mlir::memref::MemRefDialect>(); | ||
} | ||
void runOnOperation() final; | ||
|
||
StringRef getDescription() const override { | ||
return "Convert the CIR dialect module to MLIR standard dialects"; | ||
} | ||
|
||
StringRef getArgument() const override { return "cir-to-mlir"; } | ||
}; | ||
|
||
class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> { | ||
public: | ||
using OpConversionPattern<cir::GlobalOp>::OpConversionPattern; | ||
mlir::LogicalResult | ||
matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor, | ||
mlir::ConversionPatternRewriter &rewriter) const override { | ||
auto moduleOp = op->getParentOfType<mlir::ModuleOp>(); | ||
if (!moduleOp) | ||
return mlir::failure(); | ||
|
||
mlir::OpBuilder b(moduleOp.getContext()); | ||
|
||
const auto cirSymType = op.getSymType(); | ||
assert(!cir::MissingFeatures::convertTypeForMemory()); | ||
auto convertedType = getTypeConverter()->convertType(cirSymType); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please spell out the types... |
||
if (!convertedType) | ||
return mlir::failure(); | ||
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType); | ||
AaronBallman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (!memrefType) | ||
memrefType = mlir::MemRefType::get({}, convertedType); | ||
// Add an optional alignment to the global memref. | ||
assert(!cir::MissingFeatures::opGlobalAlignment()); | ||
mlir::IntegerAttr memrefAlignment = mlir::IntegerAttr(); | ||
// Add an optional initial value to the global memref. | ||
mlir::Attribute initialValue = mlir::Attribute(); | ||
std::optional<mlir::Attribute> init = op.getInitialValue(); | ||
if (init.has_value()) { | ||
initialValue = | ||
llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value()) | ||
.Case<cir::IntAttr>([&](cir::IntAttr attr) { | ||
auto rtt = mlir::RankedTensorType::get({}, convertedType); | ||
return mlir::DenseIntElementsAttr::get(rtt, attr.getValue()); | ||
}) | ||
.Case<cir::FPAttr>([&](cir::FPAttr attr) { | ||
auto rtt = mlir::RankedTensorType::get({}, convertedType); | ||
return mlir::DenseFPElementsAttr::get(rtt, attr.getValue()); | ||
}) | ||
.Default([&](mlir::Attribute attr) { | ||
llvm_unreachable("GlobalOp lowering with initial value is not " | ||
"fully supported yet"); | ||
return mlir::Attribute(); | ||
}); | ||
} | ||
|
||
// Add symbol visibility | ||
assert(!cir::MissingFeatures::opGlobalLinkage()); | ||
std::string symVisibility = "public"; | ||
|
||
assert(!cir::MissingFeatures::opGlobalConstant()); | ||
bool isConstant = false; | ||
|
||
rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>( | ||
op, b.getStringAttr(op.getSymName()), | ||
/*sym_visibility=*/b.getStringAttr(symVisibility), | ||
/*type=*/memrefType, initialValue, | ||
/*constant=*/isConstant, | ||
/*alignment=*/memrefAlignment); | ||
|
||
return mlir::success(); | ||
} | ||
}; | ||
|
||
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, | ||
mlir::TypeConverter &converter) { | ||
patterns.add<CIRGlobalOpLowering>(converter, patterns.getContext()); | ||
} | ||
|
||
static mlir::TypeConverter prepareTypeConverter() { | ||
mlir::TypeConverter converter; | ||
converter.addConversion([&](cir::PointerType type) -> mlir::Type { | ||
assert(!cir::MissingFeatures::convertTypeForMemory()); | ||
mlir::Type ty = converter.convertType(type.getPointee()); | ||
// FIXME: The pointee type might not be converted (e.g. struct) | ||
if (!ty) | ||
return nullptr; | ||
return mlir::MemRefType::get({}, ty); | ||
}); | ||
converter.addConversion( | ||
[&](mlir::IntegerType type) -> mlir::Type { return type; }); | ||
converter.addConversion( | ||
[&](mlir::FloatType type) -> mlir::Type { return type; }); | ||
converter.addConversion([&](cir::VoidType type) -> mlir::Type { return {}; }); | ||
converter.addConversion([&](cir::IntType type) -> mlir::Type { | ||
// arith dialect ops doesn't take signed integer -- drop cir sign here | ||
return mlir::IntegerType::get( | ||
type.getContext(), type.getWidth(), | ||
mlir::IntegerType::SignednessSemantics::Signless); | ||
}); | ||
converter.addConversion([&](cir::SingleType type) -> mlir::Type { | ||
return mlir::Float32Type::get(type.getContext()); | ||
}); | ||
converter.addConversion([&](cir::DoubleType type) -> mlir::Type { | ||
return mlir::Float64Type::get(type.getContext()); | ||
}); | ||
converter.addConversion([&](cir::FP80Type type) -> mlir::Type { | ||
return mlir::Float80Type::get(type.getContext()); | ||
}); | ||
converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type { | ||
return converter.convertType(type.getUnderlying()); | ||
}); | ||
converter.addConversion([&](cir::FP128Type type) -> mlir::Type { | ||
return mlir::Float128Type::get(type.getContext()); | ||
}); | ||
converter.addConversion([&](cir::FP16Type type) -> mlir::Type { | ||
return mlir::Float16Type::get(type.getContext()); | ||
}); | ||
converter.addConversion([&](cir::BF16Type type) -> mlir::Type { | ||
return mlir::BFloat16Type::get(type.getContext()); | ||
}); | ||
|
||
return converter; | ||
} | ||
|
||
void ConvertCIRToMLIRPass::runOnOperation() { | ||
auto module = getOperation(); | ||
|
||
auto converter = prepareTypeConverter(); | ||
|
||
mlir::RewritePatternSet patterns(&getContext()); | ||
|
||
populateCIRToMLIRConversionPatterns(patterns, converter); | ||
|
||
mlir::ConversionTarget target(getContext()); | ||
target.addLegalOp<mlir::ModuleOp>(); | ||
target.addLegalDialect<mlir::memref::MemRefDialect>(); | ||
target.addIllegalDialect<cir::CIRDialect>(); | ||
|
||
if (failed(applyPartialConversion(module, target, std::move(patterns)))) | ||
signalPassFailure(); | ||
} | ||
|
||
std::unique_ptr<mlir::Pass> createConvertCIRToMLIRPass() { | ||
return std::make_unique<ConvertCIRToMLIRPass>(); | ||
} | ||
|
||
mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp mlirModule, | ||
mlir::MLIRContext &mlirCtx) { | ||
llvm::TimeTraceScope scope("Lower CIR To MLIR"); | ||
|
||
mlir::PassManager pm(&mlirCtx); | ||
|
||
pm.addPass(createConvertCIRToMLIRPass()); | ||
|
||
auto result = !mlir::failed(pm.run(mlirModule)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please spell out the type. |
||
if (!result) | ||
llvm::report_fatal_error( | ||
"The pass manager failed to lower CIR to MLIR standard dialects!"); | ||
|
||
// Now that we ran all the lowering passes, verify the final output. | ||
if (mlirModule.verify().failed()) | ||
llvm::report_fatal_error( | ||
"Verification of the final MLIR in standard dialects failed!"); | ||
|
||
return mlirModule; | ||
} | ||
|
||
} // namespace cir |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One overall concern I have is that we're adding an entirely new interface to Clang where we could be introducing const-correctness from the start... but we're not doing it. That's not specific to this PR, but it is something I think should be addressed early on. Retrofitting const-correctness is hard and I realize this is touching other interfaces like MLIR and LLVM, but because this is a brand new component, we should be adding const correctness from the start everywhere possible. (IMO, if MLIR lacks const correctness, that's a problem that should be addressed in MLIR given it's also a new interface, comparatively speaking.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MLIR design is to literally never use const on its primitives: https://mlir.llvm.org/docs/Rationale/UsageOfConst/ This will probably fight back trying to be const correct only on the clang side.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow, that's disappointing to say the least. My point still stands -- Clang has been striving to improve const correctness over time and this is a new interface where we should not have to try to retrofit const correctness in the future IMO.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, so to stay consistent with the MLIR design philosophy, we won't use const with MLIR operations or values, but we should be moving everything else to const-correctness as it is upstreamed. Does that sound reasonable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that sounds like a good path forward. FWIW, I think it's reasonable for CIR to have some ugly
const_cast
uses to hide the lack of const correctness from Clang, but not to an obnoxious amount (which it sounds like would be the case with MLIR operations or values). Thanks!