[CIR] Initial implementation of lowering CIR to MLIR #127835

Open
wants to merge 4 commits into
base: main
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/CIRGenerator.h
@@ -55,6 +55,8 @@ class CIRGenerator : public clang::ASTConsumer {
  void Initialize(clang::ASTContext &astContext) override;
  bool HandleTopLevelDecl(clang::DeclGroupRef group) override;
  mlir::ModuleOp getModule() const;
  mlir::MLIRContext &getMLIRContext() { return *mlirContext; }
Collaborator

One overall concern I have is that we're adding an entirely new interface to Clang where we could be introducing const-correctness from the start... but we're not doing it. That's not specific to this PR, but it is something I think should be addressed early on. Retrofitting const-correctness is hard and I realize this is touching other interfaces like MLIR and LLVM, but because this is a brand new component, we should be adding const correctness from the start everywhere possible. (IMO, if MLIR lacks const correctness, that's a problem that should be addressed in MLIR given it's also a new interface, comparatively speaking.)

Contributor

@xlauko xlauko Feb 20, 2025

MLIR's design is to literally never use const on its primitives: https://mlir.llvm.org/docs/Rationale/UsageOfConst/ Trying to be const-correct only on the Clang side will probably run into resistance from that design.

Collaborator

@AaronBallman AaronBallman Feb 20, 2025

Wow, that's disappointing to say the least. My point still stands -- Clang has been striving to improve const correctness over time and this is a new interface where we should not have to try to retrofit const correctness in the future IMO.

Contributor Author

OK, so to stay consistent with the MLIR design philosophy, we won't use const with MLIR operations or values, but we should be moving everything else to const-correctness as it is upstreamed. Does that sound reasonable?

Collaborator

Yes, that sounds like a good path forward. FWIW, I think it's reasonable for CIR to have some ugly const_cast uses to hide the lack of const correctness from Clang, but not to an obnoxious amount (which it sounds like would be the case with MLIR operations or values). Thanks!
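
To make the agreed approach concrete, here is a minimal sketch of a const-correct wrapper around a library whose API only accepts non-const references. Everything in it is hypothetical: `LegacyContext`, `countDialects`, and `Generator` are stand-ins invented for illustration and are not Clang or MLIR types. The point is only that the `const_cast` stays in one place, behind a const member function.

```cpp
#include <memory>

// Hypothetical stand-in for a library type whose API is written without
// const (an assumption for illustration, not an MLIR class).
struct LegacyContext {
  int loadedDialects = 0;
};

// Hypothetical library function: it only reads its argument, but its
// signature takes a non-const reference.
inline int countDialects(LegacyContext &ctx) { return ctx.loadedDialects; }

class Generator {
public:
  Generator() : context(std::make_unique<LegacyContext>()) {}

  // Mutable access for callers that hold a non-const Generator.
  LegacyContext &getContext() { return *context; }
  // Read-only access for callers that hold a const Generator.
  const LegacyContext &getContext() const { return *context; }

  // Const query built on the non-const library call. The const_cast is
  // confined to this one call site; it is safe here because the underlying
  // object was not created const and countDialects does not modify it.
  int numDialects() const {
    return countDialects(const_cast<LegacyContext &>(getContext()));
  }

private:
  std::unique_ptr<LegacyContext> context;
};
```

A caller holding a `const Generator &` can call `numDialects()` and the const overload of `getContext()` without ever seeing the cast.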

  const mlir::MLIRContext &getMLIRContext() const { return *mlirContext; }
};

} // namespace cir
8 changes: 8 additions & 0 deletions clang/include/clang/CIR/FrontendAction/CIRGenAction.h
@@ -29,6 +29,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
    EmitCIR,
    EmitLLVM,
    EmitBC,
    EmitMLIR,
    EmitObj,
  };

@@ -59,6 +60,13 @@ class EmitCIRAction : public CIRGenAction {
  EmitCIRAction(mlir::MLIRContext *MLIRCtx = nullptr);
};

class EmitMLIRAction : public CIRGenAction {
  virtual void anchor();

public:
  EmitMLIRAction(mlir::MLIRContext *MLIRCtx = nullptr);
};

class EmitLLVMAction : public CIRGenAction {
  virtual void anchor();

4 changes: 4 additions & 0 deletions clang/include/clang/CIR/LowerToLLVM.h
@@ -20,6 +20,7 @@ class Module;
} // namespace llvm

namespace mlir {
class MLIRContext;
class ModuleOp;
} // namespace mlir

@@ -30,6 +31,9 @@ std::unique_ptr<llvm::Module>
lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule,
                             llvm::LLVMContext &llvmCtx);
} // namespace direct

mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp mlirModule,
                                  mlir::MLIRContext &mlirCtx);
} // namespace cir

#endif // CLANG_CIR_LOWERTOLLVM_H
7 changes: 7 additions & 0 deletions clang/include/clang/Driver/Options.td
@@ -2958,6 +2958,13 @@ defm clangir : BoolFOption<"clangir",
  BothFlags<[], [ClangOption, CC1Option], "">>;
def emit_cir : Flag<["-"], "emit-cir">, Visibility<[ClangOption, CC1Option]>,
  Group<Action_Group>, HelpText<"Build ASTs and then lower to ClangIR">;
def emit_mlir_EQ : Joined<["-"], "emit-mlir=">, Visibility<[CC1Option]>, Group<Action_Group>,
  HelpText<"Build ASTs and then generate/lower to the selected MLIR dialect, emit the .mlir or .cir file. "
  "Allowed values are `core` for MLIR core dialects and `cir` for ClangIR">,
  Values<"core,cir">,
Member

In CIR terms, considering we want to cover DirectToLLVM, ThroughMLIR (core) and CIR, shouldn't this be a three-state option: core, llvm, cir?

Contributor Author

I think eventually we will also want to put cir-flat in here, and maybe other variations, but right now these two are the only ones implemented upstream. Technically, you can also lower to the LLVM dialect, but since that's only just now being implemented in the incubator as an output, it isn't done here yet.
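
As a rough illustration of the two-value option discussed in this thread, the sketch below maps the text after `-emit-mlir=` to a dialect kind by hand. In Clang the real mapping is generated from the TableGen record (`Values`, `NormalizedValues`, `MarshallingInfoEnum`), so this is only a model of the behaviour: the scoped enum, `parseEmitMLIRValue`, and `kDefaultDialect` are hypothetical names, and the default of `cir` mirrors the `MLIR_CIR` default in the diff.

```cpp
#include <optional>
#include <string_view>

// Hypothetical mirror of frontend::MLIRDialectKind from the diff.
enum class MLIRDialectKind { CIR, Core };

// Default used when the option is not given (MLIR_CIR in the diff).
constexpr MLIRDialectKind kDefaultDialect = MLIRDialectKind::CIR;

// Maps the option value to a dialect kind. Returns std::nullopt for any
// value outside the allowed set, e.g. "llvm", which this PR does not accept.
inline std::optional<MLIRDialectKind>
parseEmitMLIRValue(std::string_view value) {
  if (value == "core")
    return MLIRDialectKind::Core;
  if (value == "cir")
    return MLIRDialectKind::CIR;
  return std::nullopt;
}
```

Adding a third state such as `llvm` or `cir-flat` later would mean one more enumerator and one more accepted string, which is why the two-value form can be extended without changing the option's shape.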

  NormalizedValuesScope<"clang::frontend">,
  NormalizedValues<["MLIR_Core", "MLIR_CIR"]>,
  MarshallingInfoEnum<FrontendOpts<"MLIRTargetDialect">, "MLIR_CIR">;
/// ClangIR-specific options - END

def flto_EQ : Joined<["-"], "flto=">,
7 changes: 7 additions & 0 deletions clang/include/clang/Frontend/FrontendOptions.h
@@ -68,6 +68,9 @@ enum ActionKind {
  /// Emit a .cir file
  EmitCIR,

  /// Emit a .mlir file
  EmitMLIR,

  /// Emit a .ll file.
  EmitLLVM,

@@ -148,6 +151,8 @@ enum ActionKind {
  PrintDependencyDirectivesSourceMinimizerOutput
};

enum MLIRDialectKind { MLIR_CIR, MLIR_Core };

} // namespace frontend

/// The kind of a file that we've been handed as an input.
@@ -417,6 +422,8 @@ class FrontendOptions {
  /// Specifies the output format of the AST.
  ASTDumpOutputFormat ASTDumpFormat = ADOF_Default;

  frontend::MLIRDialectKind MLIRTargetDialect = frontend::MLIR_CIR;

  /// The input kind, either specified via -x argument or deduced from the input
  /// file name.
  InputKind DashX;
30 changes: 27 additions & 3 deletions clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -13,6 +13,7 @@
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CodeGen/BackendUtil.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendOptions.h"
#include "llvm/IR/Module.h"

using namespace cir;
@@ -24,6 +25,7 @@ static BackendAction
getBackendActionFromOutputType(CIRGenAction::OutputType Action) {
  switch (Action) {
  case CIRGenAction::OutputType::EmitCIR:
  case CIRGenAction::OutputType::EmitMLIR:
    assert(false &&
           "Unsupported output type for getBackendActionFromOutputType!");
    break; // Unreachable, but fall through to report that
@@ -82,14 +84,30 @@ class CIRGenConsumer : public clang::ASTConsumer {
  void HandleTranslationUnit(ASTContext &C) override {
    Gen->HandleTranslationUnit(C);
    mlir::ModuleOp MlirModule = Gen->getModule();
    mlir::MLIRContext &MlirCtx = Gen->getMLIRContext();
    switch (Action) {
    case CIRGenAction::OutputType::EmitCIR:
      if (OutputStream && MlirModule) {
      assert(CI.getFrontendOpts().MLIRTargetDialect == frontend::MLIR_CIR);
    case CIRGenAction::OutputType::EmitMLIR: {
      switch (CI.getFrontendOpts().MLIRTargetDialect) {
      case frontend::MLIR_CIR:
        if (OutputStream && MlirModule) {
          mlir::OpPrintingFlags Flags;
          Flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/false);
          MlirModule->print(*OutputStream, Flags);
        }
        break;
      case frontend::MLIR_Core:
        mlir::ModuleOp LoweredMlirModule =
            lowerFromCIRToMLIR(MlirModule, MlirCtx);
        assert(OutputStream && "No output stream when lowering to MLIR!");
        // FIXME: we cannot roundtrip prettyForm=true right now.
        mlir::OpPrintingFlags Flags;
        Flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/false);
        MlirModule->print(*OutputStream, Flags);
        LoweredMlirModule->print(*OutputStream, Flags);
        break;
      }
      break;
    }
    case CIRGenAction::OutputType::EmitLLVM:
    case CIRGenAction::OutputType::EmitBC:
    case CIRGenAction::OutputType::EmitObj:
@@ -124,6 +142,8 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
    return CI.createDefaultOutputFile(false, InFile, "s");
  case CIRGenAction::OutputType::EmitCIR:
    return CI.createDefaultOutputFile(false, InFile, "cir");
  case CIRGenAction::OutputType::EmitMLIR:
    return CI.createDefaultOutputFile(false, InFile, "mlir");
  case CIRGenAction::OutputType::EmitLLVM:
    return CI.createDefaultOutputFile(false, InFile, "ll");
  case CIRGenAction::OutputType::EmitBC:
@@ -155,6 +175,10 @@ void EmitCIRAction::anchor() {}
EmitCIRAction::EmitCIRAction(mlir::MLIRContext *MLIRCtx)
    : CIRGenAction(OutputType::EmitCIR, MLIRCtx) {}

void EmitMLIRAction::anchor() {}
EmitMLIRAction::EmitMLIRAction(mlir::MLIRContext *MLIRCtx)
    : CIRGenAction(OutputType::EmitMLIR, MLIRCtx) {}

void EmitLLVMAction::anchor() {}
EmitLLVMAction::EmitLLVMAction(mlir::MLIRContext *MLIRCtx)
    : CIRGenAction(OutputType::EmitLLVM, MLIRCtx) {}
1 change: 1 addition & 0 deletions clang/lib/CIR/FrontendAction/CMakeLists.txt
@@ -13,6 +13,7 @@ add_clang_library(clangCIRFrontendAction
  clangFrontend
  clangCIR
  clangCIRLoweringDirectToLLVM
  clangCIRLoweringThroughMLIR
  clangCodeGen
  MLIRCIR
  MLIRIR
1 change: 1 addition & 0 deletions clang/lib/CIR/Lowering/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(DirectToLLVM)
add_subdirectory(ThroughMLIR)
16 changes: 16 additions & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/CMakeLists.txt
@@ -0,0 +1,16 @@
set(LLVM_LINK_COMPONENTS
  Core
  Support
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

add_clang_library(clangCIRLoweringThroughMLIR
  LowerCIRToMLIR.cpp

  DEPENDS
  LINK_LIBS
  MLIRIR
  ${dialect_libs}
  MLIRCIR
)
201 changes: 201 additions & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
@@ -0,0 +1,201 @@
//====- LowerCIRToMLIR.cpp - Lowering from CIR to MLIR --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of CIR operations to MLIR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToLLVM.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/TimeProfiler.h"

using namespace cir;
using namespace llvm;

namespace cir {

struct ConvertCIRToMLIRPass
    : public mlir::PassWrapper<ConvertCIRToMLIRPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<mlir::BuiltinDialect, mlir::memref::MemRefDialect>();
  }
  void runOnOperation() final;

  StringRef getDescription() const override {
    return "Convert the CIR dialect module to MLIR standard dialects";
  }

  StringRef getArgument() const override { return "cir-to-mlir"; }
};

class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
public:
  using OpConversionPattern<cir::GlobalOp>::OpConversionPattern;
  mlir::LogicalResult
  matchAndRewrite(cir::GlobalOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::ModuleOp moduleOp = op->getParentOfType<mlir::ModuleOp>();
    if (!moduleOp)
      return mlir::failure();

    mlir::OpBuilder b(moduleOp.getContext());

    const mlir::Type cirSymType = op.getSymType();
    assert(!cir::MissingFeatures::convertTypeForMemory());
    mlir::Type convertedType = getTypeConverter()->convertType(cirSymType);
    if (!convertedType)
      return mlir::failure();
    auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
    if (!memrefType)
      memrefType = mlir::MemRefType::get({}, convertedType);
    // Add an optional alignment to the global memref.
    assert(!cir::MissingFeatures::opGlobalAlignment());
    mlir::IntegerAttr memrefAlignment = mlir::IntegerAttr();
    // Add an optional initial value to the global memref.
    mlir::Attribute initialValue = mlir::Attribute();
    std::optional<mlir::Attribute> init = op.getInitialValue();
    if (init.has_value()) {
      initialValue =
          llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
              .Case<cir::IntAttr>([&](cir::IntAttr attr) {
                auto rtt = mlir::RankedTensorType::get({}, convertedType);
                return mlir::DenseIntElementsAttr::get(rtt, attr.getValue());
              })
              .Case<cir::FPAttr>([&](cir::FPAttr attr) {
                auto rtt = mlir::RankedTensorType::get({}, convertedType);
                return mlir::DenseFPElementsAttr::get(rtt, attr.getValue());
              })
              .Default([&](mlir::Attribute attr) {
                llvm_unreachable("GlobalOp lowering with initial value is not "
                                 "fully supported yet");
                return mlir::Attribute();
              });
    }

    // Add symbol visibility
    assert(!cir::MissingFeatures::opGlobalLinkage());
    std::string symVisibility = "public";

    assert(!cir::MissingFeatures::opGlobalConstant());
    bool isConstant = false;

    rewriter.replaceOpWithNewOp<mlir::memref::GlobalOp>(
        op, b.getStringAttr(op.getSymName()),
        /*sym_visibility=*/b.getStringAttr(symVisibility),
        /*type=*/memrefType, initialValue,
        /*constant=*/isConstant,
        /*alignment=*/memrefAlignment);

    return mlir::success();
  }
};

void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                         mlir::TypeConverter &converter) {
  patterns.add<CIRGlobalOpLowering>(converter, patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
  mlir::TypeConverter converter;
  converter.addConversion([&](cir::PointerType type) -> mlir::Type {
    assert(!cir::MissingFeatures::convertTypeForMemory());
    mlir::Type ty = converter.convertType(type.getPointee());
    // FIXME: The pointee type might not be converted (e.g. struct)
    if (!ty)
      return nullptr;
    return mlir::MemRefType::get({}, ty);
  });
  converter.addConversion(
      [&](mlir::IntegerType type) -> mlir::Type { return type; });
  converter.addConversion(
      [&](mlir::FloatType type) -> mlir::Type { return type; });
  converter.addConversion([&](cir::VoidType type) -> mlir::Type { return {}; });
  converter.addConversion([&](cir::IntType type) -> mlir::Type {
    // arith dialect ops doesn't take signed integer -- drop cir sign here
    return mlir::IntegerType::get(
        type.getContext(), type.getWidth(),
        mlir::IntegerType::SignednessSemantics::Signless);
  });
  converter.addConversion([&](cir::SingleType type) -> mlir::Type {
    return mlir::Float32Type::get(type.getContext());
  });
  converter.addConversion([&](cir::DoubleType type) -> mlir::Type {
    return mlir::Float64Type::get(type.getContext());
  });
  converter.addConversion([&](cir::FP80Type type) -> mlir::Type {
    return mlir::Float80Type::get(type.getContext());
  });
  converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type {
    return converter.convertType(type.getUnderlying());
  });
  converter.addConversion([&](cir::FP128Type type) -> mlir::Type {
    return mlir::Float128Type::get(type.getContext());
  });
  converter.addConversion([&](cir::FP16Type type) -> mlir::Type {
    return mlir::Float16Type::get(type.getContext());
  });
  converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
    return mlir::BFloat16Type::get(type.getContext());
  });

  return converter;
}

void ConvertCIRToMLIRPass::runOnOperation() {
  mlir::ModuleOp module = getOperation();

  mlir::TypeConverter converter = prepareTypeConverter();

  mlir::RewritePatternSet patterns(&getContext());

  populateCIRToMLIRConversionPatterns(patterns, converter);

  mlir::ConversionTarget target(getContext());
  target.addLegalOp<mlir::ModuleOp>();
  target.addLegalDialect<mlir::memref::MemRefDialect>();
  target.addIllegalDialect<cir::CIRDialect>();

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<mlir::Pass> createConvertCIRToMLIRPass() {
  return std::make_unique<ConvertCIRToMLIRPass>();
}

mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp mlirModule,
                                  mlir::MLIRContext &mlirCtx) {
  llvm::TimeTraceScope scope("Lower CIR To MLIR");

  mlir::PassManager pm(&mlirCtx);

  pm.addPass(createConvertCIRToMLIRPass());

  bool result = !mlir::failed(pm.run(mlirModule));
  if (!result)
    llvm::report_fatal_error(
        "The pass manager failed to lower CIR to MLIR standard dialects!");

  // Now that we ran all the lowering passes, verify the final output.
  if (mlirModule.verify().failed())
    llvm::report_fatal_error(
        "Verification of the final MLIR in standard dialects failed!");

  return mlirModule;
}

} // namespace cir
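
For readers who want to reason about the type-conversion rules in `prepareTypeConverter` without building MLIR, the following is a simplified, self-contained model of three of them: integers lose their signedness, pointers become rank-0 memrefs of the converted pointee, and a pointer whose pointee cannot be converted fails as a whole. It is a sketch under stated assumptions, not the PR's implementation: `SrcType` and `convertType` are hypothetical, and converted types are represented as MLIR-style type strings rather than `mlir::Type` values.

```cpp
#include <memory>
#include <optional>
#include <string>

// Hypothetical, simplified source type (not a CIR class).
struct SrcType {
  enum Kind { Int, Double, Pointer, Struct };
  Kind kind;
  unsigned width = 0;               // meaningful for Int only
  bool isSigned = false;            // meaningful for Int only
  std::shared_ptr<SrcType> pointee; // meaningful for Pointer only
};

// Returns the converted type in textual form, or std::nullopt when no
// conversion rule applies (the model's equivalent of a null mlir::Type).
inline std::optional<std::string> convertType(const SrcType &type) {
  switch (type.kind) {
  case SrcType::Int:
    // Signedness is dropped: signed and unsigned both map to signless iN.
    return "i" + std::to_string(type.width);
  case SrcType::Double:
    return std::string("f64");
  case SrcType::Pointer: {
    if (!type.pointee)
      return std::nullopt;
    // Convert the pointee first; if that fails, the pointer fails too.
    std::optional<std::string> inner = convertType(*type.pointee);
    if (!inner)
      return std::nullopt;
    // Rank-0 memref wrapping the converted pointee type.
    return "memref<" + *inner + ">";
  }
  case SrcType::Struct:
    // No rule for records in this model, matching the FIXME in the diff.
    return std::nullopt;
  }
  return std::nullopt;
}
```

In this model a signed and an unsigned 32-bit integer both convert to `i32`, and a pointer to either converts to `memref<i32>`, while a pointer to a struct reports failure instead of producing a partially converted type.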