Skip to content

Commit

Permalink
update upstream mlir patch
Browse files Browse the repository at this point in the history
  • Loading branch information
lipracer committed Jul 12, 2024
1 parent eba9876 commit 0850775
Showing 1 changed file with 138 additions and 25 deletions.
163 changes: 138 additions & 25 deletions cast.patch
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
commit 3ea3d4bed57c4f6a35bed044bca8c1277fa2bb17
Author: lipracer <[email protected]>
Date: Fri Mar 29 23:25:07 2024 +0800
From a2113b34ed4c5bebfc2d86187cc8d8272e3bd8ef Mon Sep 17 00:00:00 2001
From: lipracer <[email protected]>
Date: Fri, 29 Mar 2024 23:25:07 +0800
Subject: [PATCH] [mlir] fix Undefined behavior in CastInfo::castFailed with
From=<MLIR interface>

[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIR interface>

Fixes https://github.com/llvm/llvm-project/issues/86647
Fixes https://github.com/llvm/llvm-project/issues/86647

add CastInfo to support cast Interface to Op
---
config.sh | 10 +++
mlir/include/mlir/IR/OpDefinition.h | 71 ++++++++++++++++++++++
mlir/include/mlir/TableGen/Class.h | 2 +
mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +-
mlir/unittests/IR/InterfaceTest.cpp | 48 +++++++++++++++
6 files changed, 142 insertions(+), 1 deletion(-)
create mode 100644 config.sh

diff --git a/config.sh b/config.sh
new file mode 100644
index 000000000000..55ab08224a32
--- /dev/null
+++ b/config.sh
@@ -0,0 +1,10 @@
+cmake -G Ninja llvm -B build \
+ -DCMAKE_C_COMPILER=clang \
+ -DCMAKE_CXX_COMPILER=clang++ \
+ -DLLVM_ENABLE_LLD=OFF \
+ -DLLVM_ENABLE_PROJECTS=mlir \
+ -DLLVM_BUILD_EXAMPLES=ON \
+ -DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DLLVM_ENABLE_ASSERTIONS=ON \
+ -DMLIR_INCLUDE_INTEGRATION_TESTS=ON
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bd68c2744574..5ba39b80b513 100644
index 59f094d66909..52aac19289cf 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -22,6 +22,7 @@
Expand All @@ -16,9 +43,9 @@ index bd68c2744574..5ba39b80b513 100644
#include "mlir/IR/Operation.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

#include <optional>
@@ -2110,6 +2111,34 @@ struct DenseMapInfo<T,
@@ -2142,6 +2143,76 @@ struct DenseMapInfo<T,
}
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};
Expand All @@ -36,7 +63,7 @@ index bd68c2744574..5ba39b80b513 100644
+ void>> : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static bool isPossible(From &val) {
+ static inline bool isPossible(From &val) {
+ if constexpr (std::is_same_v<To, From>)
+ return true;
+ else
Expand All @@ -45,16 +72,92 @@ index bd68c2744574..5ba39b80b513 100644
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+
+ static To doCast(From &val) {
+ static inline To doCast(From &val) {
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+};
+
+template <typename OpT, typename = void>
+struct is_concrete_op_type : public std::false_type {};
+
+template <typename OpT, template <typename T> typename... Traits>
+constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
+ return mlir::Op<OpT, Traits...>(nullptr);
+}
+
+template <typename OpT>
+using concrete_op_base_type =
+ decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));
+
+template <typename OpT>
+struct is_concrete_op_type<
+ OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
+ : public std::true_type {};
+
+template <typename To, typename From>
+struct CastInfo<
+ To, From,
+ std::enable_if_t<
+ is_concrete_op_type<To>() &&
+ std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
+ typename std::remove_const_t<
+ From>::InterfaceTraits>,
+ std::remove_const_t<From>>>>
+ : NullableValueCastFailed<To>,
+ DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+
+ static inline bool isPossible(From &val) {
+ if constexpr (std::is_same_v<To, From>)
+ return true;
+ else
+ return isa<To>(
+ const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+
+ static inline To doCast(From &val) {
+ return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
+ }
+};
+
} // namespace llvm

#endif
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 92fec6a3b11d..7616f56aa2e3 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -520,6 +520,8 @@ public:
/// Write the parent class declaration.
void writeTo(raw_indented_ostream &os) const;

+ friend class OpClass;
+
private:
/// The fully resolved C++ name of the parent class.
std::string name;
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
index 60fa1833ce62..5426302dfed3 100644
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
}

void OpClass::finalize() {
+ std::string traitList;
+ llvm::raw_string_ostream os(traitList);
+ iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1,
+ std::end(parent.templateParams));
+ llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) {
+ os << trait << "<" << getClassName().str() << ">";
+ });
+ declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
Class::finalize();
+
declare<VisibilityDeclaration>(Visibility::Public);
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
}
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 2a7406f42f34..c6409e9ec30e 100644
index 4b06b92fbc8a..a1cae23c1df9 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
Expand All @@ -66,27 +169,28 @@ index 2a7406f42f34..c6409e9ec30e 100644
+ " using InterfaceTraits = detail::{2};\n",
interfaceName, interfaceName, interfaceTraitsName,
interfaceBaseType);

diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 5ab4d9a10623..7012da669248 100644
index 42196b003e7d..c9ae6938e8b4 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -16,6 +16,9 @@
#include "../../test/lib/Dialect/Test/TestAttributes.h"
@@ -17,6 +17,10 @@
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace test;
@@ -83,3 +86,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
@@ -84,3 +88,47 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}
+
+TEST(OperationInterfaceTest, CastOpToInterface) {
+TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
+ DialectRegistry registry;
+ MLIRContext ctx;
+
Expand All @@ -103,13 +207,20 @@ index 5ab4d9a10623..7012da669248 100644
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+ Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
+
+ static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
+ arith::AddIOp>,
+ "");
+ static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
+ static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");
+
+ OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);
+
+ bool constantOp =
+ llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+ .Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
+ return std::is_same_v<decltype(op), arith::ConstantOp>;
+ });
+ bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
+ .Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
+ bool is_same =
+ std::is_same_v<decltype(op), arith::ConstantOp>;
+ return is_same;
+ });
+
+ EXPECT_TRUE(constantOp);
+
Expand All @@ -122,3 +233,5 @@ index 5ab4d9a10623..7012da669248 100644
+ EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface));
+ EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
+}
--
2.25.1

0 comments on commit 0850775

Please sign in to comment.