Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] fix Undefined behavior in CastInfo::castFailed with From=<MLIRinterface> #87145

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ODSSupport.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

#include <optional>
Expand Down Expand Up @@ -2142,6 +2143,46 @@ struct DenseMapInfo<T,
}
static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};

/// The reason why these conditions are not directly used in specialized
/// parameters is that some compilers do not support short circuits between
/// several conditions.
template <typename T, bool = is_incomplete_v<T>>
struct is_complete_and_derive_from_state {
constexpr static bool value = std::is_base_of_v<mlir::OpState, T>;
};

template <typename T>
struct is_complete_and_derive_from_state<T, true> {
constexpr static bool value = false;
};

/// Add support for llvm style casts.
/// We provide a cast between To and From if To and From is mlir::OpState or
/// derives from it. To avoid some pre declared types matching here, we have
/// added a condition for whether there is a complete type defintion.
template <typename To, typename From>
struct CastInfo<To, From,
std::enable_if_t<is_complete_and_derive_from_state<To>::value &&
is_complete_and_derive_from_state<
std::remove_const_t<From>>::value,
void>>
: NullableValueCastFailed<To>,
DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {

static inline bool isPossible(From &val) {
if constexpr (std::is_same_v<To, From>)
return true;
else
return isa<To>(
const_cast<std::remove_const_t<From> &>(val).getOperation());
}

static inline To doCast(From &val) {
return To(const_cast<std::remove_const_t<From> &>(val).getOperation());
}
};

} // namespace llvm

#endif
40 changes: 40 additions & 0 deletions mlir/unittests/IR/InterfaceTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "../../test/lib/Dialect/Test/TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace test;
Expand Down Expand Up @@ -84,3 +87,40 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}

TEST(OperationInterfaceTest, CastInterfaceToOpOrInterface) {
DialectRegistry registry;
MLIRContext ctx;

const char *ir = R"MLIR(
func.func @map(%arg : tensor<1xi64>) {
%0 = arith.constant dense<[10]> : tensor<1xi64>
%1 = arith.addi %arg, %0 : tensor<1xi64>
return
}
)MLIR";

registry.insert<func::FuncDialect, arith::ArithDialect>();
ctx.appendDialectRegistry(registry);
OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
Operation &op = cast<func::FuncOp>(module->front()).getBody().front().front();
OpAsmOpInterface interface = llvm::cast<OpAsmOpInterface>(op);

bool constantOp = llvm::TypeSwitch<OpAsmOpInterface, bool>(interface)
.Case<arith::AddIOp, arith::ConstantOp>([&](auto op) {
bool is_same =
std::is_same_v<decltype(op), arith::ConstantOp>;
return is_same;
});

EXPECT_TRUE(constantOp);

EXPECT_FALSE(llvm::isa<VectorUnrollOpInterface>(interface));
EXPECT_FALSE(llvm::dyn_cast<VectorUnrollOpInterface>(interface));

EXPECT_TRUE(llvm::isa<InferTypeOpInterface>(interface));
EXPECT_TRUE(llvm::dyn_cast<InferTypeOpInterface>(interface));

EXPECT_TRUE(llvm::isa<OpAsmOpInterface>(interface));
EXPECT_TRUE(llvm::dyn_cast<OpAsmOpInterface>(interface));
}
lipracer marked this conversation as resolved.
Show resolved Hide resolved
Loading