diff --git a/.github/workflows/buildAndTestStructured.yml b/.github/workflows/buildAndTestStructured.yml index 3e64d38e088d..0c42cb781015 100644 --- a/.github/workflows/buildAndTestStructured.yml +++ b/.github/workflows/buildAndTestStructured.yml @@ -49,6 +49,12 @@ jobs: path: sandbox submodules: recursive + - name: Patch LLVM with WIP patch for testing + run: | + cd ${STRUCTURED_MAIN_SRC_DIR}/third_party/llvm-project + patch -p1 < ../../cast.patch + git diff + - name: Install Ninja uses: llvm/actions/install-ninja@6a57890d0e3f9f35dfc72e7e48bc5e1e527cdd6c # Jan 17 diff --git a/cast.patch b/cast.patch new file mode 100644 index 000000000000..29e296128f8a --- /dev/null +++ b/cast.patch @@ -0,0 +1,220 @@ +From d50b3d3473fb9a04e1f57797e0d043719a571969 Mon Sep 17 00:00:00 2001 +From: lipracer +Date: Fri, 29 Mar 2024 23:25:07 +0800 +Subject: [PATCH] [mlir] fix Undefined behavior in CastInfo::castFailed with + From= + +Fixes https://github.com/llvm/llvm-project/issues/86647 + +add CastInfo to support cast Interface to Op +--- + mlir/include/mlir/IR/OpDefinition.h | 71 ++++++++++++++++++++++ + mlir/include/mlir/TableGen/Class.h | 2 + + mlir/tools/mlir-tblgen/OpClass.cpp | 9 +++ + mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 3 +- + mlir/unittests/IR/InterfaceTest.cpp | 48 +++++++++++++++ + 5 files changed, 132 insertions(+), 1 deletion(-) + +diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h +index 59f094d66909..52aac19289cf 100644 +--- a/mlir/include/mlir/IR/OpDefinition.h ++++ b/mlir/include/mlir/IR/OpDefinition.h +@@ -22,6 +22,7 @@ + #include "mlir/IR/Dialect.h" + #include "mlir/IR/ODSSupport.h" + #include "mlir/IR/Operation.h" ++#include "llvm/Support/Casting.h" + #include "llvm/Support/PointerLikeTypeTraits.h" + + #include +@@ -2142,6 +2143,76 @@ struct DenseMapInfo ++struct CastInfo< ++ To, From, ++ std::enable_if_t< ++ std::is_base_of_v, ++ To> && ++ std::is_base_of_v, ++ typename std::remove_const_t< ++ From>::InterfaceTraits>, ++ std::remove_const_t>, ++ void>> : NullableValueCastFailed, ++ DefaultDoCastIfPossible> { ++ ++ static inline bool isPossible(From &val) { ++ if constexpr (std::is_same_v) ++ return true; ++ else ++ return mlir::OpInterface:: ++ InterfaceBase::classof( ++ const_cast &>(val).getOperation()); ++ } ++ ++ static inline To doCast(From &val) { ++ return To(const_cast &>(val).getOperation()); ++ } ++}; ++ ++template ++struct is_concrete_op_type : public std::false_type {}; ++ ++template typename... Traits> ++constexpr auto concrete_op_base_type_impl(std::tuple...>) { ++ return mlir::Op(nullptr); ++} ++ ++template ++using concrete_op_base_type = ++ decltype(concrete_op_base_type_impl(typename OpT::traits())); ++ ++template ++struct is_concrete_op_type< ++ OpT, std::enable_if_t, OpT>>> ++ : public std::true_type {}; ++ ++template ++struct CastInfo< ++ To, From, ++ std::enable_if_t< ++ is_concrete_op_type() && ++ std::is_base_of_v, ++ typename std::remove_const_t< ++ From>::InterfaceTraits>, ++ std::remove_const_t>>> ++ : NullableValueCastFailed, ++ DefaultDoCastIfPossible> { ++ ++ static inline bool isPossible(From &val) { ++ if constexpr (std::is_same_v) ++ return true; ++ else ++ return isa( ++ const_cast &>(val).getOperation()); ++ } ++ ++ static inline To doCast(From &val) { ++ return To(const_cast &>(val).getOperation()); ++ } ++}; ++ + } // namespace llvm + + #endif +diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h +index 92fec6a3b11d..7616f56aa2e3 100644 +--- a/mlir/include/mlir/TableGen/Class.h ++++ b/mlir/include/mlir/TableGen/Class.h +@@ -520,6 +520,8 @@ public: + /// Write the parent class declaration. + void writeTo(raw_indented_ostream &os) const; + ++ friend class OpClass; ++ + private: + /// The fully resolved C++ name of the parent class. + std::string name; +diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp +index 60fa1833ce62..5426302dfed3 100644 +--- a/mlir/tools/mlir-tblgen/OpClass.cpp ++++ b/mlir/tools/mlir-tblgen/OpClass.cpp +@@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration, + } + + void OpClass::finalize() { ++ std::string traitList; ++ llvm::raw_string_ostream os(traitList); ++ iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1, ++ std::end(parent.templateParams)); ++ llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) { ++ os << trait << "<" << getClassName().str() << ">"; ++ }); ++ declare("traits", "std::tuple<" + traitList + ">"); + Class::finalize(); ++ + declare(Visibility::Public); + declare(extraClassDeclaration, extraClassDefinition); + } +diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +index 4b06b92fbc8a..a1cae23c1df9 100644 +--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp ++++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +@@ -544,7 +544,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { + // Emit the main interface class declaration. + os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" + "public:\n" +- " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", ++ " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" ++ " using InterfaceTraits = detail::{2};\n", + interfaceName, interfaceName, interfaceTraitsName, + interfaceBaseType); + +diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp +index 42196b003e7d..c9ae6938e8b4 100644 +--- a/mlir/unittests/IR/InterfaceTest.cpp ++++ b/mlir/unittests/IR/InterfaceTest.cpp +@@ -17,6 +17,10 @@ + #include "../../test/lib/Dialect/Test/TestDialect.h" + #include "../../test/lib/Dialect/Test/TestOps.h" + #include "../../test/lib/Dialect/Test/TestTypes.h" ++#include "mlir/Dialect/Arith/IR/Arith.h" ++#include "mlir/Dialect/SCF/IR/SCF.h" ++#include "mlir/Parser/Parser.h" ++#include "llvm/ADT/TypeSwitch.h" + + using namespace mlir; + using namespace test; +@@ -84,3 +88,47 @@ TEST(InterfaceTest, TestImplicitConversion) { + typeA = typeB; + EXPECT_EQ(typeA, typeB); + } ++ ++TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) { ++ DialectRegistry registry; ++ MLIRContext ctx; ++ ++ const char *ir = R"MLIR( ++ func.func @map(%arg : tensor<1xi64>) { ++ %0 = arith.constant dense<[10]> : tensor<1xi64> ++ %1 = arith.addi %arg, %0 : tensor<1xi64> ++ return ++ } ++ )MLIR"; ++ ++ registry.insert(); ++ ctx.appendDialectRegistry(registry); ++ OwningOpRef module = parseSourceString(ir, &ctx); ++ Operation &op = cast(module->front()).getBody().front().front(); ++ ++ static_assert(std::is_base_of_v, ++ arith::AddIOp>, ++ ""); ++ static_assert(llvm::is_concrete_op_type(), ""); ++ static_assert(!llvm::is_concrete_op_type(), ""); ++ ++ OpAsmOpInterface interface = llvm::cast(op); ++ ++ bool constantOp = llvm::TypeSwitch(interface) ++ .Case([&](auto op) { ++ bool is_same = ++ std::is_same_v; ++ return is_same; ++ }); ++ ++ EXPECT_TRUE(constantOp); ++ ++ EXPECT_FALSE(llvm::isa(interface)); ++ EXPECT_FALSE(llvm::dyn_cast(interface)); ++ ++ EXPECT_TRUE(llvm::isa(interface)); ++ EXPECT_TRUE(llvm::dyn_cast(interface)); ++ ++ EXPECT_TRUE(llvm::isa(interface)); ++ EXPECT_TRUE(llvm::dyn_cast(interface)); ++} +-- +2.25.1 + diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index b58ebf099e93..0f4a12481675 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -9,6 +9,7 @@ #include "structured/Target/SubstraitPB/Export.h" #include "ProtobufUtils.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/Support/LogicalResult.h" #include "structured/Dialect/Substrait/IR/Substrait.h" #include "structured/Target/SubstraitPB/Options.h" @@ -261,10 +262,13 @@ FailureOr> SubstraitExporter::exportOperation(EmitOp op) { FailureOr> SubstraitExporter::exportOperation(ExpressionOpInterface op) { - return llvm::TypeSwitch>>( - op) - .Case( - [&](auto op) { return exportOperation(op); }) + return llvm::TypeSwitch>>(op) + .Case([&](auto op) { + llvm::errs() << __PRETTY_FUNCTION__ << "\n"; + op.dump(); + return exportOperation(op); + }) .Default( [](auto op) { return op->emitOpError("not supported for export"); }); }