Skip to content

Commit

Permalink
add CastInfo to support cast Interface to Op
Browse files Browse the repository at this point in the history
  • Loading branch information
lipracer committed Aug 27, 2024
1 parent ab2a26d commit 3be1ad2
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 8 deletions.
53 changes: 51 additions & 2 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,9 @@ struct DenseMapInfo<T,
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};

/// Add support for llvm style casts.
/// We provide a cast between To and From if To and From is mlir::OpInterface or
/// derives from it.
template <typename To, typename From>
struct CastInfo<
To, From,
Expand All @@ -2157,7 +2160,7 @@ struct CastInfo<
void>> : NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {

static bool isPossible(From &val) {
static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
Expand All @@ -2166,7 +2169,53 @@ struct CastInfo<
const_cast<std::remove_const_t<From> &>(val).getOperation());
}

static To doCast(From &val) {
static inline To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};

template <typename OpT, typename = void>
struct is_concrete_op_type : public std::false_type {};

template <typename OpT, template <typename T> typename... Traits>
constexpr auto concrete_op_base_type_impl(std::tuple<Traits<OpT>...>) {
return mlir::Op<OpT, Traits...>(nullptr);
}

template <typename OpT>
using concrete_op_base_type =
decltype(concrete_op_base_type_impl<OpT>(typename OpT::traits()));

template <typename OpT>
struct is_concrete_op_type<
OpT, std::enable_if_t<std::is_base_of_v<concrete_op_base_type<OpT>, OpT>>>
: public std::true_type {};

/// Add support for llvm style casts.
/// We provide a cast between To and From if To is mlir::Op<ConcreteType,
/// Trait0, Trait1, ...> or derives from it and From is mlir::OpInterface or
/// derives from it.
template <typename To, typename From>
struct CastInfo<
To, From,
std::enable_if_t<
is_concrete_op_type<To>() &&
std::is_base_of_v<mlir::OpInterface<std::remove_const_t<From>,
typename std::remove_const_t<
From>::InterfaceTraits>,
std::remove_const_t<From>>>>
: NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {

static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
return isa<To>(
const_cast<std::remove_const_t<From> &>(val).getOperation());
}

static inline To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/TableGen/Class.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ class ParentClass {
/// Write the parent class declaration.
void writeTo(raw_indented_ostream &os) const;

friend class OpClass;

private:
/// The fully resolved C++ name of the parent class.
std::string name;
Expand Down
9 changes: 9 additions & 0 deletions mlir/tools/mlir-tblgen/OpClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,16 @@ OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
}

void OpClass::finalize() {
std::string traitList;
llvm::raw_string_ostream os(traitList);
iterator_range parentTemplateParams(std::begin(parent.templateParams) + 1,
std::end(parent.templateParams));
llvm::interleaveComma(parentTemplateParams, os, [&](auto &trait) {
os << trait << "<" << getClassName().str() << ">";
});
declare<UsingDeclaration>("traits", "std::tuple<" + traitList + ">");
Class::finalize();

declare<VisibilityDeclaration>(Visibility::Public);
declare<ExtraClassDeclaration>(extraClassDeclaration, extraClassDefinition);
}
19 changes: 13 additions & 6 deletions mlir/unittests/IR/InterfaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ TEST(InterfaceTest, TestImplicitConversion) {
EXPECT_EQ(typeA, typeB);
}

TEST(OperationInterfaceTest, CastOpToInterface) {
TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
DialectRegistry registry;
MLIRContext ctx;

Expand All @@ -105,13 +105,20 @@ TEST(OperationInterfaceTest, CastOpToInterface) {
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();

static_assert(std::is_base_of_v<llvm::concrete_op_base_type<arith::AddIOp>,
arith::AddIOp>,
"");
static_assert(llvm::is_concrete_op_type<arith::AddIOp>(), "");
static_assert(!llvm::is_concrete_op_type<OpAsmOpInterface>(), "");

OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);

bool constantOp =
llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
.Case<VectorUnrollOpInterface, arith::ConstantOp>([&](auto op) {
return std::is_same_v<decltype(op), arith::ConstantOp>;
});
bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
.Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
bool is_same =
std::is_same_v<decltype(op), arith::ConstantOp>;
return is_same;
});

EXPECT_TRUE(constantOp);

Expand Down

0 comments on commit 3be1ad2

Please sign in to comment.