From fd32cf987f381bb68ffde962be1e07ca2c6d5512 Mon Sep 17 00:00:00 2001 From: lipracer Date: Fri, 29 Mar 2024 23:25:07 +0800 Subject: [PATCH 1/3] [mlir] fix Undefined behavior in CastInfo::castFailed with From= Fixes https://github.com/llvm/llvm-project/issues/86647 --- mlir/include/mlir/IR/OpDefinition.h | 29 ++++++++++++++++ mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +- mlir/unittests/IR/InterfaceTest.cpp | 40 ++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 59f094d6690991..5610daadfbecb5 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -22,6 +22,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/ODSSupport.h" #include "mlir/IR/Operation.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include @@ -2142,6 +2143,34 @@ struct DenseMapInfo +struct CastInfo< + To, From, + std::enable_if_t< + std::is_base_of_v, + To> && + std::is_base_of_v, + typename std::remove_const_t< + From>::InterfaceTraits>, + std::remove_const_t>, + void>> : NullableValueCastFailed, + DefaultDoCastIfPossible> { + + static bool isPossible(From &val) { + if constexpr (std::is_same_v) + return true; + else + return mlir::OpInterface:: + InterfaceBase::classof( + const_cast &>(val).getOperation()); + } + + static To doCast(From &val) { + return To(const_cast &>(val).getOperation()); + } +}; + } // namespace llvm #endif diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 1f1b1d9a340391..c8233d19da4b05 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -545,7 +545,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" "public:\n" - " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", + " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" + " using InterfaceTraits = detail::{2};\n", interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 42196b003e7dad..741365b3efb5fc 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -17,6 +17,9 @@ #include "../../test/lib/Dialect/Test/TestDialect.h" #include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestTypes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Parser/Parser.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace test; @@ -84,3 +87,40 @@ TEST(InterfaceTest, TestImplicitConversion) { typeA = typeB; EXPECT_EQ(typeA, typeB); } + +TEST(OperationInterfaceTest, CastOpToInterface) { + DialectRegistry registry; + MLIRContext ctx; + + const char *ir = R"MLIR( + func.func @map(%arg : tensor<1xi64>) { + %0 = arith.constant dense<[10]> : tensor<1xi64> + %1 = arith.addi %arg, %0 : tensor<1xi64> + return + } + )MLIR"; + + registry.insert(); + ctx.appendDialectRegistry(registry); + OwningOpRef module = parseSourceString(ir, &ctx); + Operation &op = cast(module->front()).getBody().front().front(); + + OpAsmOpInterface interface = llvm::cast(op); + + bool constantOp = + llvm::TypeSwitch(interface) + .Case([&](auto op) { + return std::is_same_v; + }); + + EXPECT_TRUE(constantOp); + + EXPECT_FALSE(llvm::isa(interface)); + EXPECT_FALSE(llvm::dyn_cast(interface)); + + EXPECT_TRUE(llvm::isa(interface)); + EXPECT_TRUE(llvm::dyn_cast(interface)); + + EXPECT_TRUE(llvm::isa(interface)); + EXPECT_TRUE(llvm::dyn_cast(interface)); +} From cc59eaca3b16fcccbf6b806037a70a03f02b6767 Mon Sep 17 00:00:00 2001 From: lipracer Date: Thu, 11 Jul 2024 11:06:45 -0400 Subject: [PATCH 2/3] add CastInfo to support cast Interface to Op --- mlir/include/mlir/IR/OpDefinition.h | 53 +++++++++++++++++++++++++++-- mlir/include/mlir/TableGen/Class.h | 2 ++ mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++++ mlir/unittests/IR/InterfaceTest.cpp | 19 +++++++---- 4 files changed, 75 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 5610daadfbecb5..763d0f6ff2cb8d 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -2144,6 +2144,9 @@ struct DenseMapInfo struct CastInfo< To, From, @@ -2157,7 +2160,7 @@ struct CastInfo< void>> : NullableValueCastFailed, DefaultDoCastIfPossible> { - static bool isPossible(From &val) { + static inline bool isPossible(From &val) { if constexpr (std::is_same_v) return true; else @@ -2166,7 +2169,53 @@ struct CastInfo< const_cast &>(val).getOperation()); } - static To doCast(From &val) { + static inline To doCast(From &val) { + return To(const_cast &>(val).getOperation()); + } +}; + +template +struct is_concrete_op_type : public std::false_type {}; + +template typename... Traits> +constexpr auto concrete_op_base_type_impl(std::tuple...>) { + return mlir::Op(nullptr); +} + +template +using concrete_op_base_type = + decltype(concrete_op_base_type_impl(typename OpT::traits())); + +template +struct is_concrete_op_type< + OpT, std::enable_if_t, OpT>>> + : public std::true_type {}; + +/// Add support for llvm style casts. +/// We provide a cast between To and From if To is mlir::Op or derives from it and From is mlir::OpInterface or +/// derives from it. +template +struct CastInfo< + To, From, + std::enable_if_t< + is_concrete_op_type() && + std::is_base_of_v, + typename std::remove_const_t< + From>::InterfaceTraits>, + std::remove_const_t>>> + : NullableValueCastFailed, + DefaultDoCastIfPossible> { + + static inline bool isPossible(From &val) { + if constexpr (std::is_same_v) + return true; + else + return isa( + const_cast &>(val).getOperation()); + } + + static inline To doCast(From &val) { return To(const_cast &>(val).getOperation()); } }; diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index f750a34a3b2ba4..5cb9aa4e6d21ba 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -521,6 +521,8 @@ class ParentClass { /// Write the parent class declaration. void writeTo(raw_indented_ostream &os) const; + friend class OpClass; + private: /// The fully resolved C++ name of the parent class. std::string name; diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp index 60fa1833ce625e..5426302dfed3e3 100644 --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration, } void OpClass::finalize() { + std::string traitList; + llvm::raw_string_ostream os(traitList); + iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1, + std::end(parent.templateParams)); + llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) { + os << trait << "<" << getClassName().str() << ">"; + }); + declare("traits", "std::tuple<" + traitList + ">"); Class::finalize(); + declare(Visibility::Public); declare(extraClassDeclaration, extraClassDefinition); } diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 741365b3efb5fc..6c983385679b18 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -88,7 +88,7 @@ TEST(InterfaceTest, TestImplicitConversion) { EXPECT_EQ(typeA, typeB); } -TEST(OperationInterfaceTest, CastOpToInterface) { +TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) { DialectRegistry registry; MLIRContext ctx; @@ -105,13 +105,20 @@ TEST(OperationInterfaceTest, CastOpToInterface) { OwningOpRef module = parseSourceString(ir, &ctx); Operation &op = cast(module->front()).getBody().front().front(); + static_assert(std::is_base_of_v, + arith::AddIOp>, + ""); + static_assert(llvm::is_concrete_op_type(), ""); + static_assert(!llvm::is_concrete_op_type(), ""); + OpAsmOpInterface interface = llvm::cast(op); - bool constantOp = - llvm::TypeSwitch(interface) - .Case([&](auto op) { - return std::is_same_v; - }); + bool constantOp = llvm::TypeSwitch(interface) + .Case([&](auto op) { + bool is_same = + std::is_same_v; + return is_same; + }); EXPECT_TRUE(constantOp); From 11f7d95c29ae74106a05d078d72ccc3833d763f7 Mon Sep 17 00:00:00 2001 From: lipracer Date: Mon, 4 Nov 2024 07:37:40 -0500 Subject: [PATCH 3/3] refine --- mlir/include/mlir/IR/OpDefinition.h | 73 ++++++---------------- mlir/include/mlir/TableGen/Class.h | 2 - mlir/tools/mlir-tblgen/OpClass.cpp | 9 --- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +- mlir/unittests/IR/InterfaceTest.cpp | 7 --- 5 files changed, 19 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 763d0f6ff2cb8d..7bfdbe09ed961b 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -2144,66 +2144,29 @@ struct DenseMapInfo -struct CastInfo< - To, From, - std::enable_if_t< - std::is_base_of_v, - To> && - std::is_base_of_v, - typename std::remove_const_t< - From>::InterfaceTraits>, - std::remove_const_t>, - void>> : NullableValueCastFailed, - DefaultDoCastIfPossible> { - - static inline bool isPossible(From &val) { - if constexpr (std::is_same_v) - return true; - else - return mlir::OpInterface:: - InterfaceBase::classof( - const_cast &>(val).getOperation()); - } - - static inline To doCast(From &val) { - return To(const_cast &>(val).getOperation()); - } +/// The reason why these conditions are not directly used in specialized +/// parameters is that some compilers do not support short circuits between +/// several conditions. +template > +struct is_complete_and_derive_from_state { + constexpr static bool value = std::is_base_of_v; }; -template -struct is_concrete_op_type : public std::false_type {}; - -template typename... Traits> -constexpr auto concrete_op_base_type_impl(std::tuple...>) { - return mlir::Op(nullptr); -} - -template -using concrete_op_base_type = - decltype(concrete_op_base_type_impl(typename OpT::traits())); - -template -struct is_concrete_op_type< - OpT, std::enable_if_t, OpT>>> - : public std::true_type {}; +template +struct is_complete_and_derive_from_state { + constexpr static bool value = false; +}; /// Add support for llvm style casts. -/// We provide a cast between To and From if To is mlir::Op or derives from it and From is mlir::OpInterface or -/// derives from it. +/// We provide a cast between To and From if To and From is mlir::OpState or +/// derives from it. To avoid some pre declared types matching here, we have +/// added a condition for whether there is a complete type defintion. template -struct CastInfo< - To, From, - std::enable_if_t< - is_concrete_op_type() && - std::is_base_of_v, - typename std::remove_const_t< - From>::InterfaceTraits>, - std::remove_const_t>>> +struct CastInfo::value && + is_complete_and_derive_from_state< + std::remove_const_t>::value, + void>> : NullableValueCastFailed, DefaultDoCastIfPossible> { diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h index 5cb9aa4e6d21ba..f750a34a3b2ba4 100644 --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -521,8 +521,6 @@ class ParentClass { /// Write the parent class declaration. void writeTo(raw_indented_ostream &os) const; - friend class OpClass; - private: /// The fully resolved C++ name of the parent class. std::string name; diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp index 5426302dfed3e3..60fa1833ce625e 100644 --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -36,16 +36,7 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration, } void OpClass::finalize() { - std::string traitList; - llvm::raw_string_ostream os(traitList); - iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1, - std::end(parent.templateParams)); - llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) { - os << trait << "<" << getClassName().str() << ">"; - }); - declare("traits", "std::tuple<" + traitList + ">"); Class::finalize(); - declare(Visibility::Public); declare(extraClassDeclaration, extraClassDefinition); } diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index c8233d19da4b05..1f1b1d9a340391 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -545,8 +545,7 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" "public:\n" - " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" - " using InterfaceTraits = detail::{2};\n", + " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 6c983385679b18..73d43e87556380 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -104,13 +104,6 @@ TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) { ctx.appendDialectRegistry(registry); OwningOpRef module = parseSourceString(ir, &ctx); Operation &op = cast(module->front()).getBody().front().front(); - - static_assert(std::is_base_of_v, - arith::AddIOp>, - ""); - static_assert(llvm::is_concrete_op_type(), ""); - static_assert(!llvm::is_concrete_op_type(), ""); - OpAsmOpInterface interface = llvm::cast(op); bool constantOp = llvm::TypeSwitch(interface)