-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: lipracer <[email protected]>
- Loading branch information
Showing
1 changed file
with
138 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,40 @@ | ||
commit 3ea3d4bed57c4f6a35bed044bca8c1277fa2bb17 | ||
Author: lipracer <[email protected]> | ||
Date: Fri Mar 29 23:25:07 2024 +0800 | ||
From a2113b34ed4c5bebfc2d86187cc8d8272e3bd8ef Mon Sep 17 00:00:00 2001 | ||
From: lipracer <[email protected]> | ||
Date: Fri, 29 Mar 2024 23:25:07 +0800 | ||
Subject: [PATCH] [mlir] fix Undefined behavior in CastInfo::castFailed with | ||
From=<MLIR interface> | ||
|
||
[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR interface> | ||
|
||
Fixes https://github.com/llvm/llvm-project/issues/86647 | ||
Fixes https://github.com/llvm/llvm-project/issues/86647 | ||
|
||
add CastInfo to support cast Interface to Op | ||
--- | ||
config.sh | 10 +++ | ||
mlir/include/mlir/IR/OpDefinition.h | 71 ++++++++++++++++++++++ | ||
mlir/include/mlir/TableGen/Class.h | 2 + | ||
mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++ | ||
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +- | ||
mlir/unittests/IR/InterfaceTest.cpp | 48 +++++++++++++++ | ||
6 files changed, 142 insertions(+), 1 deletion(-) | ||
create mode 100644 config.sh | ||
|
||
diff --git a/config.sh b/config.sh | ||
new file mode 100644 | ||
index 000000000000..55ab08224a32 | ||
--- /dev/null | ||
+++ b/config.sh | ||
@@ -0,0 +1,10 @@ | ||
+cmake -G Ninja llvm -B build \ | ||
+ -DCMAKE_C_COMPILER=clang \ | ||
+ -DCMAKE_CXX_COMPILER=clang++ \ | ||
+ -DLLVM_ENABLE_LLD=OFF \ | ||
+ -DLLVM_ENABLE_PROJECTS=mlir \ | ||
+ -DLLVM_BUILD_EXAMPLES=ON \ | ||
+ -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \ | ||
+ -DCMAKE_BUILD_TYPE=Release \ | ||
+ -DLLVM_ENABLE_ASSERTIONS=ON \ | ||
+ -DMLIR_INCLUDE_INTEGRATION_TESTS=ON | ||
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h | ||
index bd68c2744574..5ba39b80b513 100644 | ||
index 59f094d66909..52aac19289cf 100644 | ||
--- a/mlir/include/mlir/IR/OpDefinition.h | ||
+++ b/mlir/include/mlir/IR/OpDefinition.h | ||
@@ -22,6 +22,7 @@ | ||
|
@@ -16,9 +43,9 @@ index bd68c2744574..5ba39b80b513 100644 | |
#include "mlir/IR/Operation.h" | ||
+#include "llvm/Support/Casting.h" | ||
#include "llvm/Support/PointerLikeTypeTraits.h" | ||
|
||
#include <optional> | ||
@@ -2110,6 +2111,34 @@ struct DenseMapInfo<T, | ||
@@ -2142,6 +2143,76 @@ struct DenseMapInfo<T, | ||
} | ||
static bool isEqual(T lhs, T rhs) { return lhs == rhs; } | ||
}; | ||
|
@@ -36,7 +63,7 @@ index bd68c2744574..5ba39b80b513 100644 | |
+ void>> : NullableValueCastFailed<To>, | ||
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { | ||
+ | ||
+ static bool isPossible(From &val) { | ||
+ static inline bool isPossible(From &val) { | ||
+ if constexpr (std::is_same_v<To, From>) | ||
+ return true; | ||
+ else | ||
|
@@ -45,16 +72,92 @@ index bd68c2744574..5ba39b80b513 100644 | |
+ const_cast<std::remove_const_t<From> &>(val).getOperation()); | ||
+ } | ||
+ | ||
+ static To doCast(From &val) { | ||
+ static inline To doCast(From &val) { | ||
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation()); | ||
+ } | ||
+}; | ||
+ | ||
+template <typename OpT, typename = void> | ||
+struct is_concrete_op_type : public std::false_type {}; | ||
+ | ||
+template <typename OpT, template <typename T> typename... Traits> | ||
+constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) { | ||
+ return mlir::Op<OpT, Traits...>(nullptr); | ||
+} | ||
+ | ||
+template <typename OpT> | ||
+using concrete_op_base_type = | ||
+ decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits())); | ||
+ | ||
+template <typename OpT> | ||
+struct is_concrete_op_type< | ||
+ OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>> | ||
+ : public std::true_type {}; | ||
+ | ||
+template <typename To, typename From> | ||
+struct CastInfo< | ||
+ To, From, | ||
+ std::enable_if_t< | ||
+ is_concrete_op_type<To>() && | ||
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>, | ||
+ typename std::remove_const_t< | ||
+ From>::InterfaceTraits>, | ||
+ std::remove_const_t<From>>>> | ||
+ : NullableValueCastFailed<To>, | ||
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { | ||
+ | ||
+ static inline bool isPossible(From &val) { | ||
+ if constexpr (std::is_same_v<To, From>) | ||
+ return true; | ||
+ else | ||
+ return isa<To>( | ||
+ const_cast<std::remove_const_t<From> &>(val).getOperation()); | ||
+ } | ||
+ | ||
+ static inline To doCast(From &val) { | ||
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation()); | ||
+ } | ||
+}; | ||
+ | ||
} // namespace llvm | ||
|
||
#endif | ||
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h | ||
index 92fec6a3b11d..7616f56aa2e3 100644 | ||
--- a/mlir/include/mlir/TableGen/Class.h | ||
+++ b/mlir/include/mlir/TableGen/Class.h | ||
@@ -520,6 +520,8 @@ public: | ||
/// Write the parent class declaration. | ||
void writeTo(raw_indented_ostream &os) const; | ||
|
||
+ friend class OpClass; | ||
+ | ||
private: | ||
/// The fully resolved C++ name of the parent class. | ||
std::string name; | ||
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp | ||
index 60fa1833ce62..5426302dfed3 100644 | ||
--- a/mlir/tools/mlir-tblgen/OpClass.cpp | ||
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp | ||
@@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration, | ||
} | ||
|
||
void OpClass::finalize() { | ||
+ std::string traitList; | ||
+ llvm::raw_string_ostream os(traitList); | ||
+ iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1, | ||
+ std::end(parent.templateParams)); | ||
+ llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) { | ||
+ os << trait << "<" << getClassName().str() << ">"; | ||
+ }); | ||
+ declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">"); | ||
Class::finalize(); | ||
+ | ||
declare<VisibilityDeclaration>(Visibility::Public); | ||
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition); | ||
} | ||
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | ||
index 2a7406f42f34..c6409e9ec30e 100644 | ||
index 4b06b92fbc8a..a1cae23c1df9 100644 | ||
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | ||
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | ||
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { | ||
|
@@ -66,27 +169,28 @@ index 2a7406f42f34..c6409e9ec30e 100644 | |
+ " using InterfaceTraits = detail::{2};\n", | ||
interfaceName, interfaceName, interfaceTraitsName, | ||
interfaceBaseType); | ||
|
||
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp | ||
index 5ab4d9a10623..7012da669248 100644 | ||
index 42196b003e7d..c9ae6938e8b4 100644 | ||
--- a/mlir/unittests/IR/InterfaceTest.cpp | ||
+++ b/mlir/unittests/IR/InterfaceTest.cpp | ||
@@ -16,6 +16,9 @@ | ||
#include "../../test/lib/Dialect/Test/TestAttributes.h" | ||
@@ -17,6 +17,10 @@ | ||
#include "../../test/lib/Dialect/Test/TestDialect.h" | ||
#include "../../test/lib/Dialect/Test/TestOps.h" | ||
#include "../../test/lib/Dialect/Test/TestTypes.h" | ||
+#include "mlir/Dialect/Arith/IR/Arith.h" | ||
+#include "mlir/Dialect/SCF/IR/SCF.h" | ||
+#include "mlir/Parser/Parser.h" | ||
+#include "llvm/ADT/TypeSwitch.h" | ||
|
||
using namespace mlir; | ||
using namespace test; | ||
@@ -83,3 +86,40 @@ TEST(InterfaceTest, TestImplicitConversion) { | ||
@@ -84,3 +88,47 @@ TEST(InterfaceTest, TestImplicitConversion) { | ||
typeA = typeB; | ||
EXPECT_EQ(typeA, typeB); | ||
} | ||
+ | ||
+TEST(OperationInterfaceTest, CastOpToInterface) { | ||
+TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) { | ||
+ DialectRegistry registry; | ||
+ MLIRContext ctx; | ||
+ | ||
|
@@ -103,13 +207,20 @@ index 5ab4d9a10623..7012da669248 100644 | |
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx); | ||
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front(); | ||
+ | ||
+ static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>, | ||
+ arith::AddIOp>, | ||
+ ""); | ||
+ static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), ""); | ||
+ static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), ""); | ||
+ | ||
+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op); | ||
+ | ||
+ bool constantOp = | ||
+ llvm::TypeSwitch<OpAsmOpInterface, bool>(interface) | ||
+ .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) { | ||
+ return std::is_same_v<decltype(op), arith::ConstantOp>; | ||
+ }); | ||
+ bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface) | ||
+ .Case<arith::AddIOp, arith::ConstantOp>([&](auto op) { | ||
+ bool is_same = | ||
+ std::is_same_v<decltype(op), arith::ConstantOp>; | ||
+ return is_same; | ||
+ }); | ||
+ | ||
+ EXPECT_TRUE(constantOp); | ||
+ | ||
|
@@ -122,3 +233,5 @@ index 5ab4d9a10623..7012da669248 100644 | |
+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface)); | ||
+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface)); | ||
+} | ||
-- | ||
2.25.1 |