Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-springer committed Jul 6, 2024
1 parent 55b95a7 commit cfea4ad
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 54 deletions.
3 changes: 2 additions & 1 deletion mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value.
/// a signature conversion of a single block argument, to a single SSA value
/// with the old argument type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
"apply_conversion_patterns.scf.scf_to_control_flow",
[DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
let description = [{
Collects patterns that lower structured control flow ops to unstructured
control flow.
}];

let assemblyFormat = "attr-dict";
}

def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;

def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ class TypeConverter {

/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value.
/// a signature conversion of a single block argument, to a single SSA value
/// with the old block argument type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
Expand Down
28 changes: 21 additions & 7 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.isVarArg());
});

// Materialization for memrefs creates descriptor structs from individual
// values constituting them, when descriptors are used, i.e. more than one
// value represents a memref.
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type. The dialect conversion framework will then
// insert a target materialization from the original block argument type to
// a legal type.
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
Expand All @@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
}
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
inputs);
Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
resultType, inputs);
// An argument materialization must return a value of type
// `resultType`, so insert a cast from the memref descriptor type
// (!llvm.struct) to the original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
Value desc;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
Expand All @@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return std::nullopt;
return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
} else {
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
}
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
// An argument materialization must return a value of type `resultType`,
// so insert a cast from the memref descriptor type (!llvm.struct) to the
// original memref type.
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
.getResult(0);
});
// Add generic source and target materializations to handle cases where
// non-LLVM types persist after an LLVM conversion.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
MLIRIR
MLIRLoopLikeInterface
MLIRSCFDialect
MLIRSCFToControlFlow
MLIRSCFTransforms
MLIRSCFUtils
MLIRTransformDialect
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"

#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
conversionTarget);
}

void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
populateSCFToControlFlowConversionPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//
Expand Down
83 changes: 40 additions & 43 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target,
Type origOutputType = nullptr)
MaterializationKind kind = MaterializationKind::Target)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), origOutputType(origOutputType) {}
converterAndKind(converter, kind) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
Expand All @@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}

/// Return the original illegal output type of the input values.
Type getOrigOutputType() const { return origOutputType; }

private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
converterAndKind;

/// The original output type. This is only used for argument conversions.
Type origOutputType;
};
} // namespace

Expand Down Expand Up @@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Block *insertBlock,
Block::iterator insertPt, Location loc,
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);

Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type origOutputType,
Type outputType,
const TypeConverter *converter);

Expand Down Expand Up @@ -1388,20 +1379,24 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
mapping.map(origArg, newArg);
} else {
Type origOutputType = origArg.getType();

// Legalize the argument output type.
Type outputType = origOutputType;
if (Type legalOutputType = converter->convertType(outputType))
outputType = legalOutputType;

newArg = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
converter);
// Build argument materialization: new block arguments -> old block
// argument type.
Value argMat = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
mapping.map(origArg, argMat);

// Build target materialization: old block argument type -> legal type.
if (Type legalOutputType = converter->convertType(origArg.getType())) {
newArg = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, newArg);
} else {
newArg = argMat;
}
}

mapping.map(origArg, newArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
Expand All @@ -1424,7 +1419,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
Location loc, ValueRange inputs, Type outputType, Type origOutputType,
Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
Expand All @@ -1435,16 +1430,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
OpBuilder builder(insertBlock, insertPt);
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
origOutputType);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
Block *block, Location loc, ValueRange inputs, Type origOutputType,
Type outputType, const TypeConverter *converter) {
Block *block, Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
origOutputType, converter);
converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
Expand All @@ -1456,7 +1450,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(

return buildUnresolvedMaterialization(MaterializationKind::Target,
insertBlock, insertPt, loc, input,
outputType, outputType, converter);
outputType, converter);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2672,19 +2666,28 @@ static void computeNecessaryMaterializations(
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping,
SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
// Helper function to check if the given value or a not yet materialized
// replacement of the given value is live.
// Note: `inverseMapping` maps from replaced values to original values.
auto isLive = [&](Value value) {
auto findFn = [&](Operation *user) {
auto matIt = materializationOps.find(user);
if (matIt != materializationOps.end())
return !necessaryMaterializations.count(matIt->second);
return rewriterImpl.isOpIgnored(user);
};
// This value may be replacing another value that has a live user.
for (Value inv : inverseMapping.lookup(value))
if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
// A worklist is needed because a value may have gone through a chain of
// replacements and each of the replaced values may have live users.
SmallVector<Value> worklist;
worklist.push_back(value);
while (!worklist.empty()) {
Value next = worklist.pop_back_val();
if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
return true;
// Or have live users itself.
return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
// This value may be replacing another value that has a live user.
llvm::append_range(worklist, inverseMapping.lookup(next));
}
return false;
};

llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
Expand Down Expand Up @@ -2844,18 +2847,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
switch (mat.getMaterializationKind()) {
case MaterializationKind::Argument:
// Try to materialize an argument conversion.
// FIXME: The current argument materialization hook expects the original
// output type, even though it doesn't use that as the actual output type
// of the generated IR. The output type is just used as an indicator of
// the type of materialization to do. This behavior is really awkward in
// that it diverges from the behavior of the other hooks, and can be
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
newMaterialization = converter->materializeArgumentConversion(
rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
rewriter, op->getLoc(), outputType, inputOperands);
if (newMaterialization)
break;

// If an argument materialization failed, fallback to trying a target
// materialization.
[[fallthrough]];
Expand All @@ -2865,6 +2860,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
break;
}
if (newMaterialization) {
assert(newMaterialization.getType() == opResult.getType() &&
"materialization callback produced value of incorrect type");
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
inverseMapping);
return success();
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s

// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR
// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR

// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR

// These tests were separated from func-memref.mlir because applying
// -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
Expand Down
44 changes: 44 additions & 0 deletions mlir/test/Transforms/test-block-legalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s

// CHECK-LABEL: func @complex_block_signature_conversion(
// CHECK: %[[cst:.*]] = complex.constant
// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
// Note: Some blocks are omitted.
// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]]
// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
// CHECK: llvm.br ^[[block2:.*]]
// CHECK: ^[[block2]]:
// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
func.func @complex_block_signature_conversion() {
%cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
%true = arith.constant true
%0 = scf.if %true -> complex<f64> {
scf.yield %cst : complex<f64>
} else {
scf.yield %cst : complex<f64>
}

// Regression test to ensure that the a source materialization is inserted.
// The operand of "test.consumer_of_complex" must not change.
"test.consumer_of_complex"(%0) : (complex<f64>) -> ()
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
: (!transform.any_op) -> !transform.any_op
transform.apply_conversion_patterns to %func {
transform.apply_conversion_patterns.dialect_to_llvm "cf"
transform.apply_conversion_patterns.func.func_to_llvm
transform.apply_conversion_patterns.scf.scf_to_control_flow
} with type_converter {
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
} {
legal_dialects = ["llvm"],
partial_conversion
} : !transform.any_op
transform.yield
}
}

0 comments on commit cfea4ad

Please sign in to comment.