Skip to content

Commit

Permalink
[flang] Remove materialization workaround in type converter
Browse files Browse the repository at this point in the history
This change is in preparation of #97903, which adds extra checks for materializations: it is now enforced that they produce an SSA value of the correct type, so the current workaround no longer works.

For `fir.has_value` the fix is simple: no target materializations on the operands are performed if the lowering patterns is initialized without a type converter. For `cg::XEmboxOp`, the existing workaround that skips `unrealized_conversion_cast` ops can be generalized. (This is still a workaround.)

Also remove the lowering pattern for `unrealized_conversion_cast`. This pattern has no effect because `unrealized_conversion_cast` ops that are inserted by the dialect conversion framework are never matched by the pattern driver.
  • Loading branch information
matthias-springer committed Jul 13, 2024
1 parent 3b7a7f4 commit b105e0e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 81 deletions.
5 changes: 5 additions & 0 deletions flang/include/flang/Tools/CLOptions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
/// This file defines some shared command-line options that can be used when
/// debugging the test tools. This file must be included into the tool.

#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -223,6 +224,10 @@ inline void addFIRToLLVMPass(
options.forceUnifiedTBAATree = useOldAliasTags;
addPassConditionally(pm, disableFirToLlvmIr,
[&]() { return fir::createFIRToLLVMPass(options); });
// The dialect conversion framework may leave dead unrealized_conversion_cast
// ops behind, so run reconcile-unrealized-casts to clean them up.
addPassConditionally(pm, disableFirToLlvmIr,
[&]() { return mlir::createReconcileUnrealizedCastsPass(); });
}

inline void addLLVMDialectToLLVMPass(
Expand Down
74 changes: 27 additions & 47 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
Expand Down Expand Up @@ -1726,10 +1725,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
// fir.box was translated to an llvm.ptr<llvm.struct<>> and the MLIR pass
// manager inserted a builtin.unrealized_conversion_cast that was inserted
// and needs to be removed here.
if (isInGlobalOp(rewriter))
if (auto unrealizedCast =
loweredBox.getDefiningOp<mlir::UnrealizedConversionCastOp>())
loweredBox = unrealizedCast.getInputs()[0];
if (auto unrealizedCast =
loweredBox.getDefiningOp<mlir::UnrealizedConversionCastOp>())
loweredBox = unrealizedCast.getInputs()[0];

TypePair inputBoxTyPair = getBoxTypePair(rebox.getBox().getType());

Expand Down Expand Up @@ -2042,13 +2040,13 @@ struct ExtractValueOpConversion
/// InsertValue is the generalized instruction for the composition of new
/// aggregate type values.
struct InsertValueOpConversion
: public fir::FIROpAndTypeConversion<fir::InsertValueOp>,
: public mlir::OpConversionPattern<fir::InsertValueOp>,
public ValueOpCommon {
using FIROpAndTypeConversion::FIROpAndTypeConversion;
using OpConversionPattern::OpConversionPattern;

llvm::LogicalResult
doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
matchAndRewrite(fir::InsertValueOp insertVal, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
mlir::ValueRange operands = adaptor.getOperands();
auto indices = collectIndices(rewriter, insertVal.getCoor());
toRowMajor(indices, operands[0].getType());
Expand Down Expand Up @@ -2669,8 +2667,9 @@ struct TypeDescOpConversion : public fir::FIROpConversion<fir::TypeDescOp> {
};

/// Lower `fir.has_value` operation to `llvm.return` operation.
struct HasValueOpConversion : public fir::FIROpConversion<fir::HasValueOp> {
using FIROpConversion::FIROpConversion;
struct HasValueOpConversion
: public mlir::OpConversionPattern<fir::HasValueOp> {
using OpConversionPattern::OpConversionPattern;

llvm::LogicalResult
matchAndRewrite(fir::HasValueOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -3515,29 +3514,6 @@ struct MustBeDeadConversion : public fir::FIROpConversion<FromOp> {
}
};

struct UnrealizedConversionCastOpConversion
: public fir::FIROpConversion<mlir::UnrealizedConversionCastOp> {
using FIROpConversion::FIROpConversion;

llvm::LogicalResult
matchAndRewrite(mlir::UnrealizedConversionCastOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
assert(op.getOutputs().getTypes().size() == 1 && "expect a single type");
mlir::Type convertedType = convertType(op.getOutputs().getTypes()[0]);
if (convertedType == adaptor.getInputs().getTypes()[0]) {
rewriter.replaceOp(op, adaptor.getInputs());
return mlir::success();
}

convertedType = adaptor.getInputs().getTypes()[0];
if (convertedType == op.getOutputs().getType()[0]) {
rewriter.replaceOp(op, adaptor.getInputs());
return mlir::success();
}
return mlir::failure();
}
};

struct ShapeOpConversion : public MustBeDeadConversion<fir::ShapeOp> {
using MustBeDeadConversion::MustBeDeadConversion;
};
Expand Down Expand Up @@ -3714,7 +3690,8 @@ class FIRToLLVMLowering
signalPassFailure();
}

// Run pass to add comdats to functions that have weak linkage on relevant platforms
// Run pass to add comdats to functions that have weak linkage on relevant
// platforms
if (fir::getTargetTriple(mod).supportsCOMDAT()) {
mlir::OpPassManager comdatPM("builtin.module");
comdatPM.addPass(mlir::LLVM::createLLVMAddComdats());
Expand Down Expand Up @@ -3789,16 +3766,19 @@ void fir::populateFIRToLLVMConversionPatterns(
DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
GlobalOpConversion, HasValueOpConversion, InsertOnRangeOpConversion,
InsertValueOpConversion, IsPresentOpConversion, LenParamIndexOpConversion,
LoadOpConversion, MulcOpConversion, NegcOpConversion,
NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion,
SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
StoreOpConversion, StringLitOpConversion, SubcOpConversion,
TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
UnrealizedConversionCastOpConversion, XArrayCoorOpConversion,
XEmboxOpConversion, XReboxOpConversion, ZeroOpConversion>(converter,
options);
GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
SliceOpConversion, StoreOpConversion, StringLitOpConversion,
SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
XReboxOpConversion, ZeroOpConversion>(converter, options);

// Patterns that are populated without a type converter do not trigger
// target materializations for the operands of the root op.
patterns.insert<HasValueOpConversion, InsertValueOpConversion>(
patterns.getContext());
}
34 changes: 0 additions & 34 deletions flang/lib/Optimizer/CodeGen/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,40 +122,6 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
// Convert it here to i1 just in case it survives.
return mlir::IntegerType::get(&getContext(), 1);
});
// FIXME: https://reviews.llvm.org/D82831 introduced an automatic
// materialization of conversion around function calls that is not working
// well with fir lowering to llvm (incorrect llvm.mlir.cast are inserted).
// Workaround until better analysis: register a handler that does not insert
// any conversions.
addSourceMaterialization(
[&](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> std::optional<mlir::Value> {
if (inputs.size() != 1)
return std::nullopt;
return inputs[0];
});
// Similar FIXME workaround here (needed for compare.fir/select-type.fir
// as well as rebox-global.fir tests). This is needed to cope with the
// the fact that codegen does not lower some operation results to the LLVM
// type produced by this LLVMTypeConverter. For instance, inside FIR
// globals, fir.box are lowered to llvm.struct, while the fir.box type
// conversion translates it into an llvm.ptr<llvm.struct<>> because
// descriptors are manipulated in memory outside of global initializers
// where this is not possible. Hence, MLIR inserts
// builtin.unrealized_conversion_cast after the translation of operations
// producing fir.box in fir.global codegen. addSourceMaterialization and
// addTargetMaterialization allow ignoring these ops and removing them
// after codegen assuming the type discrepencies are intended (like for
// fir.box inside globals).
addTargetMaterialization(
[&](mlir::OpBuilder &builder, mlir::Type resultType,
mlir::ValueRange inputs,
mlir::Location loc) -> std::optional<mlir::Value> {
if (inputs.size() != 1)
return std::nullopt;
return inputs[0];
});
}

// i32 is used here because LLVM wants i32 constants when indexing into struct
Expand Down
1 change: 1 addition & 0 deletions flang/test/Fir/basic-program.fir
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,5 @@ func.func @_QQmain() {
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations eliminated
// PASSES-NEXT: TargetRewrite
// PASSES-NEXT: FIRToLLVMLowering
// PASSES-NEXT: ReconcileUnrealizedCasts
// PASSES-NEXT: LLVMIRLoweringPass

0 comments on commit b105e0e

Please sign in to comment.