diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 59f094d6690991..7bfdbe09ed961b 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -22,6 +22,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/ODSSupport.h" #include "mlir/IR/Operation.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include @@ -2142,6 +2143,46 @@ struct DenseMapInfo> +struct is_complete_and_derive_from_state { + constexpr static bool value = std::is_base_of_v; +}; + +template +struct is_complete_and_derive_from_state { + constexpr static bool value = false; +}; + +/// Add support for llvm style casts. +/// We provide a cast between To and From if To and From is mlir::OpState or +/// derives from it. To avoid some pre declared types matching here, we have +/// added a condition for whether there is a complete type defintion. +template +struct CastInfo::value && + is_complete_and_derive_from_state< + std::remove_const_t>::value, + void>> + : NullableValueCastFailed, + DefaultDoCastIfPossible> { + + static inline bool isPossible(From &val) { + if constexpr (std::is_same_v) + return true; + else + return isa( + const_cast &>(val).getOperation()); + } + + static inline To doCast(From &val) { + return To(const_cast &>(val).getOperation()); + } +}; + } // namespace llvm #endif diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 42196b003e7dad..73d43e87556380 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -17,6 +17,9 @@ #include "../../test/lib/Dialect/Test/TestDialect.h" #include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestTypes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Parser/Parser.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace test; @@ -84,3 +87,40 @@ TEST(InterfaceTest, TestImplicitConversion) { typeA = typeB; EXPECT_EQ(typeA, typeB); } + +TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) { + DialectRegistry registry; + MLIRContext ctx; + + const char *ir = R"MLIR( + func.func @map(%arg : tensor<1xi64>) { + %0 = arith.constant dense<[10]> : tensor<1xi64> + %1 = arith.addi %arg, %0 : tensor<1xi64> + return + } + )MLIR"; + + registry.insert(); + ctx.appendDialectRegistry(registry); + OwningOpRef module = parseSourceString(ir, &ctx); + Operation &op = cast(module->front()).getBody().front().front(); + OpAsmOpInterface interface = llvm::cast(op); + + bool constantOp = llvm::TypeSwitch(interface) + .Case([&](auto op) { + bool is_same = + std::is_same_v; + return is_same; + }); + + EXPECT_TRUE(constantOp); + + EXPECT_FALSE(llvm::isa(interface)); + EXPECT_FALSE(llvm::dyn_cast(interface)); + + EXPECT_TRUE(llvm::isa(interface)); + EXPECT_TRUE(llvm::dyn_cast(interface)); + + EXPECT_TRUE(llvm::isa(interface)); + EXPECT_TRUE(llvm::dyn_cast(interface)); +}