Skip to content

Commit

Permalink
[RTTI] Replace std::dynamic_(pointer)?_casts with ov::as_type_(ptr)? …
Browse files Browse the repository at this point in the history
…- FEs (#28397)

### Details:
- Replaced `std::dynamic_cast` and `std::dynamic_pointed_cast` with
`ov::as_type` or `ov::as_type_ptr` respectively in src/frontends and
src/tests directories, where applicable.

### Tickets:
 - CVS-160241

---------

Signed-off-by: Tomasz Jankowski <[email protected]>
  • Loading branch information
t-jankowski authored Jan 14, 2025
1 parent a4ee2df commit 42cb92e
Show file tree
Hide file tree
Showing 80 changed files with 220 additions and 242 deletions.
2 changes: 1 addition & 1 deletion src/core/include/openvino/core/type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ typename std::enable_if<
bool>::value,
bool>::type
is_type(Value value) {
return value->get_type_info().is_castable(Type::get_type_info_static());
return value && value->get_type_info().is_castable(Type::get_type_info_static());
}

/// Casts a Value* to a Type* if it is of type Type, nullptr otherwise
Expand Down
14 changes: 7 additions & 7 deletions src/frontends/ir/src/ir_deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,18 +533,18 @@ std::shared_ptr<ov::Model> ov::XmlDeserializer::parse_function(const pugi::xml_n
auto node = create_node(inputs, p.xml, weights, p.params);
id_to_node[layer_id] = node;

if (const auto& parameter_node = std::dynamic_pointer_cast<ov::op::v0::Parameter>(node)) {
if (const auto& parameter_node = ov::as_type_ptr<ov::op::v0::Parameter>(node)) {
io_map.inputs.insert({layer_id, func_nodes.parameters.size()});
func_nodes.parameters.emplace_back(parameter_node);
}

if (const auto& result_node = std::dynamic_pointer_cast<ov::op::v0::Result>(node)) {
if (const auto& result_node = ov::as_type_ptr<ov::op::v0::Result>(node)) {
io_map.outputs.insert({layer_id, func_nodes.results.size()});
func_nodes.results.emplace_back(result_node);
}

if (const auto& sink = std::dynamic_pointer_cast<ov::op::Sink>(node)) {
auto subgraph_op = std::dynamic_pointer_cast<ov::op::util::MultiSubGraphOp>(node);
if (const auto& sink = ov::as_type_ptr<ov::op::Sink>(node)) {
auto subgraph_op = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(node);
if (subgraph_op) {
for (const auto& body_model : subgraph_op->get_functions()) {
if (body_model->get_sinks().size()) {
Expand All @@ -557,7 +557,7 @@ std::shared_ptr<ov::Model> ov::XmlDeserializer::parse_function(const pugi::xml_n
}
}

if (const auto& read_value = std::dynamic_pointer_cast<ov::op::util::ReadValueBase>(node)) {
if (const auto& read_value = ov::as_type_ptr<ov::op::util::ReadValueBase>(node)) {
variable_id_to_read_value[read_value->get_variable_id()] = read_value;
}

Expand All @@ -569,7 +569,7 @@ std::shared_ptr<ov::Model> ov::XmlDeserializer::parse_function(const pugi::xml_n
func_nodes.parameters,
pugixml::get_str_attr(root, "name", ""));
for (const auto& sink : func_nodes.sinks) {
if (const auto& assign = std::dynamic_pointer_cast<ov::op::util::AssignBase>(sink)) {
if (const auto& assign = ov::as_type_ptr<ov::op::util::AssignBase>(sink)) {
assign->add_control_dependency(variable_id_to_read_value.at(assign->get_variable_id()));
}
}
Expand Down Expand Up @@ -902,7 +902,7 @@ std::shared_ptr<ov::Node> ov::XmlDeserializer::create_node(const std::vector<ov:
OPENVINO_THROW("Opset ", params.version, " doesn't contain the operation with type: ", type_name);
}
// Share Weights from constant blob
if (auto constant = std::dynamic_pointer_cast<ov::op::v0::Constant>(ovNode)) {
if (auto constant = ov::as_type_ptr<ov::op::v0::Constant>(ovNode)) {
constant->alloc_buffer_on_visit_attributes(false);
}
ovNode->set_arguments(inputs);
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/jax/src/node_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ Any NodeContext::get_values_from_const_input(int index) const {
index,
" does not exist.");
auto input_val = get_input(index);
if (auto input = std::dynamic_pointer_cast<JaxFrameworkNode>(input_val.get_node_shared_ptr())) {
if (auto input = ov::as_type_ptr<JaxFrameworkNode>(input_val.get_node_shared_ptr())) {
const auto& attrs = input->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
return {};
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/onnx/frontend/src/core/null_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ std::shared_ptr<ov::Node> NullNode::clone_with_new_inputs(const ov::OutputVector
} // namespace ov

bool ov::op::util::is_null(const ov::Node* node) {
return dynamic_cast<const ov::frontend::onnx::NullNode*>(node) != nullptr;
return ov::as_type<const ov::frontend::onnx::NullNode>(node) != nullptr;
}

bool ov::op::util::is_null(const std::shared_ptr<ov::Node>& node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
CHECK_VALID_NODE(node, blob_size > 0, "Wrong blob size: ", blob_size);
// in documentation: ...Input B is a 2D constant Matrix.
CHECK_VALID_NODE(node,
dynamic_cast<v0::Constant*>(b_quantized.get_node()) != nullptr,
ov::as_type<v0::Constant>(b_quantized.get_node()) != nullptr,
"MatMulNBits limitation: accepting only a constant as a B input");
CHECK_VALID_NODE(node,
b_quantized.get_partial_shape().rank() == 3,
Expand Down Expand Up @@ -112,7 +112,7 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
}

{
const auto b_const = std::dynamic_pointer_cast<v0::Constant>(b_quantized.get_node_shared_ptr());
const auto b_const = ov::as_type_ptr<v0::Constant>(b_quantized.get_node_shared_ptr());

ov::Output<ov::Node> casted_b;
ov::Shape casted_b_shape;
Expand Down
8 changes: 4 additions & 4 deletions src/frontends/onnx/frontend/src/utils/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ bool collect_translation_exceptions(const std::shared_ptr<ov::Model>& partially_
};

for (const auto& node : partially_converted->get_ordered_ops()) {
if (const auto& fw_node = std::dynamic_pointer_cast<ov::frontend::onnx::ONNXFrameworkNode>(node)) {
if (const auto& fw_node = ov::as_type_ptr<ov::frontend::onnx::ONNXFrameworkNode>(node)) {
const auto& attrs = fw_node->get_attrs();
auto node_name = attrs.get_opset_name() + "." + attrs.get_type_name();
if (unsupported_operations->count(node_name) > 0) {
Expand All @@ -230,7 +230,7 @@ bool collect_translation_exceptions(const std::shared_ptr<ov::Model>& partially_

print_unsupported(fw_node);
unsupported_operations->insert(node_name);
} else if (const auto& fw_node = std::dynamic_pointer_cast<ov::frontend::onnx::NotSupportedONNXNode>(node)) {
} else if (const auto& fw_node = ov::as_type_ptr<ov::frontend::onnx::NotSupportedONNXNode>(node)) {
const auto& attrs = fw_node->get_attrs();

if (fw_node->additional_error_message().empty()) {
Expand All @@ -248,7 +248,7 @@ bool collect_translation_exceptions(const std::shared_ptr<ov::Model>& partially_
failures->insert(node_fail);
}

} else if (const auto& if_node = std::dynamic_pointer_cast<ov::op::v8::If>(node)) {
} else if (const auto& if_node = ov::as_type_ptr<ov::op::v8::If>(node)) {
collect_translation_exceptions(if_node->get_then_body(),
telemetry,
output_stream,
Expand All @@ -259,7 +259,7 @@ bool collect_translation_exceptions(const std::shared_ptr<ov::Model>& partially_
output_stream,
unsupported_operations,
failures);
} else if (const auto& loop_node = std::dynamic_pointer_cast<ov::op::v5::Loop>(node)) {
} else if (const auto& loop_node = ov::as_type_ptr<ov::op::v5::Loop>(node)) {
collect_translation_exceptions(loop_node->get_function(),
telemetry,
output_stream,
Expand Down
6 changes: 3 additions & 3 deletions src/frontends/onnx/frontend/src/utils/onnx_internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void remove_dangling_parameters(std::shared_ptr<ov::Model>& model) {
std::all_of(parameter_users.begin(),
parameter_users.end(),
[](const std::shared_ptr<ov::Node>& node) -> bool {
return std::dynamic_pointer_cast<ov::frontend::onnx::ONNXFrameworkNode>(node) != nullptr;
return ov::as_type_ptr<ov::frontend::onnx::ONNXFrameworkNode>(node) != nullptr;
});
if (is_dangling_parameter) {
model->remove_parameter(parameter);
Expand Down Expand Up @@ -69,8 +69,8 @@ void convert_decoded_model(std::shared_ptr<ov::Model> model) {
"' attribute in decoded model. Model probably wasn't created by FrontEnd::decode function.");
auto onnx_graph = it->second.as<std::shared_ptr<ov::frontend::onnx::Graph>>();
for (const auto& node : model->get_ordered_ops()) {
if (auto raw_node = std::dynamic_pointer_cast<ov::frontend::onnx::ONNXFrameworkNode>(node)) {
if (auto subgraph_node = std::dynamic_pointer_cast<ov::frontend::onnx::ONNXSubgraphFrameworkNode>(node)) {
if (auto raw_node = ov::as_type_ptr<ov::frontend::onnx::ONNXFrameworkNode>(node)) {
if (auto subgraph_node = ov::as_type_ptr<ov::frontend::onnx::ONNXSubgraphFrameworkNode>(node)) {
subgraph_node->infer_inputs_from_parent();
for (auto& model : subgraph_node->get_subgraph_models()) {
convert_decoded_model(model);
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/onnx/tests/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST(ONNXConversionExtensionTest, custom_op_with_custom_domain) {
OV_ASSERT_NO_THROW(model = onnx::tests::convert_model("missing_op_domain.onnx", ext));

for (const auto& op : model->get_ops()) {
if (const auto& add = std::dynamic_pointer_cast<ov::op::v1::Add>(op)) {
if (const auto& add = ov::as_type_ptr<ov::op::v1::Add>(op)) {
EXPECT_TRUE(add->get_rt_info().count("added_by_extension") == 1);
return;
}
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/onnx/tests/convert_partially_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace {
std::shared_ptr<ov::op::util::FrameworkNode> get_framework_node_with_out_name(const std::shared_ptr<ov::Model>& model,
const std::string& out_name) {
for (const auto& op : model->get_ops()) {
if (auto framework_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(op)) {
if (auto framework_node = ov::as_type_ptr<ov::op::util::FrameworkNode>(op)) {
for (const auto& out : op->outputs()) {
if (out.get_any_name() == out_name) {
return framework_node;
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/onnx/tests/onnx_import_convpool.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_max_pool_empty_auto_pad) {
const auto model = convert_model("max_pool_empty_auto_pad.onnx");

for (const auto& op : model->get_ops()) {
if (const auto max_pool = std::dynamic_pointer_cast<op::v8::MaxPool>(op)) {
if (const auto max_pool = ov::as_type_ptr<op::v8::MaxPool>(op)) {
EXPECT_EQ(max_pool->get_auto_pad(), op::PadType::EXPLICIT);
return;
}
Expand Down
6 changes: 3 additions & 3 deletions src/frontends/onnx/tests/onnx_tensor_names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ bool matching_node_found_in_graph(const std::vector<DerivedFromNode>& ops,
const std::unordered_set<std::string>& output_names,
int out_tensor_number = 0) {
return std::any_of(std::begin(ops), std::end(ops), [&](const DerivedFromNode op) {
if (const std::shared_ptr<OpType> casted = std::dynamic_pointer_cast<OpType>(op)) {
if (const std::shared_ptr<OpType> casted = ov::as_type_ptr<OpType>(op)) {
const auto& op_friendly_name = casted->get_friendly_name();
const auto& op_output_names = casted->get_output_tensor(out_tensor_number).get_names();
if (op_friendly_name == friendly_name && op_output_names == output_names) {
Expand All @@ -44,11 +44,11 @@ template <typename OpType, typename DerivedFromNode>
std::shared_ptr<OpType> find_by_friendly_name(const std::vector<DerivedFromNode>& ops,
const std::string& friendly_name) {
const auto it = std::find_if(std::begin(ops), std::end(ops), [&friendly_name](const DerivedFromNode& op) {
return op->get_friendly_name() == friendly_name && std::dynamic_pointer_cast<OpType>(op) != nullptr;
return op->get_friendly_name() == friendly_name && ov::as_type_ptr<OpType>(op) != nullptr;
});

if (it != std::end(ops)) {
return std::dynamic_pointer_cast<OpType>(*it);
return ov::as_type_ptr<OpType>(*it);
} else {
return nullptr;
}
Expand Down
2 changes: 1 addition & 1 deletion src/frontends/paddle/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ std::shared_ptr<ov::Model> FrontEnd::convert(const InputModel::Ptr& model) const
void FrontEnd::convert(const std::shared_ptr<ov::Model>& partiallyConverted) const {
for (const auto& node : partiallyConverted->get_ordered_ops()) {
if (ov::is_type<FrameworkNode>(node)) {
paddle::normalize_framework_node(std::dynamic_pointer_cast<FrameworkNode>(node), m_op_translators);
paddle::normalize_framework_node(ov::as_type_ptr<FrameworkNode>(node), m_op_translators);
}
}
for (const auto& result : partiallyConverted->get_results()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,20 @@ ov::frontend::paddle::pass::TransformFakeQuantize::TransformFakeQuantize() {

// check round mode
// Fallback to the PDPD FE if the round_mode is HALF_AWAY_FROM_ZERO.
const auto& round_node_cast = std::dynamic_pointer_cast<Round>(opsMap.at(round_label).get_node_shared_ptr());
const auto& round_node_cast = ov::as_type_ptr<Round>(opsMap.at(round_label).get_node_shared_ptr());
if (!round_node_cast || round_node_cast->get_mode() != Round::RoundMode::HALF_TO_EVEN) {
return false;
}

// check quantize_linear zero_point
auto zp_node_cast = std::dynamic_pointer_cast<Constant>(opsMap.at(dq_zp_label).get_node_shared_ptr());
auto zp_node_cast = ov::as_type_ptr<Constant>(opsMap.at(dq_zp_label).get_node_shared_ptr());
float zp;
if (!zp_node_cast || !ov::op::util::get_single_value(zp_node_cast, zp)) {
return false;
}

// prepare levels
const auto& clamp_node_cast = std::dynamic_pointer_cast<Clamp>(opsMap.at(q_clamp_label).get_node_shared_ptr());
const auto& clamp_node_cast = ov::as_type_ptr<Clamp>(opsMap.at(q_clamp_label).get_node_shared_ptr());
if (!clamp_node_cast) {
return false;
}
Expand All @@ -93,7 +93,7 @@ ov::frontend::paddle::pass::TransformFakeQuantize::TransformFakeQuantize() {
const auto levels = high_range - low_range + 1;

// get the scale
const auto& scale_node_cast = std::dynamic_pointer_cast<Constant>(
const auto& scale_node_cast = ov::as_type_ptr<Constant>(
opsMap.at(q_real_scale_label).get_node_shared_ptr()->get_input_node_shared_ptr(0));
float scale;
if (!scale_node_cast || !ov::op::util::get_single_value(scale_node_cast, scale)) {
Expand Down
3 changes: 1 addition & 2 deletions src/frontends/paddle/src/internal/pass/transform_if.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ ov::frontend::paddle::pass::TransformIf::TransformIf(std::vector<std::shared_ptr
const auto cond_label = pattern::wrap_type<ov::op::internal::ConditionalBlock>();

matcher_pass_callback callback = [funcs](pattern::Matcher& m) -> bool {
const auto conditional_block =
std::dynamic_pointer_cast<ov::op::internal::ConditionalBlock>(m.get_match_root());
const auto conditional_block = ov::as_type_ptr<ov::op::internal::ConditionalBlock>(m.get_match_root());
const auto mask_idx = conditional_block->get_input_size() - 1;
const auto cond = conditional_block->get_input_node_shared_ptr(mask_idx);

Expand Down
2 changes: 1 addition & 1 deletion src/frontends/paddle/src/internal/pass/transform_while.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ov::frontend::paddle::pass::TransformWhile::TransformWhile(std::vector<std::shar
const auto while_label = pattern::wrap_type<ov::op::internal::While>();

matcher_pass_callback callback = [functions](pattern::Matcher& m) -> bool {
const auto& while_node = std::dynamic_pointer_cast<ov::op::internal::While>(m.get_match_root());
const auto& while_node = ov::as_type_ptr<ov::op::internal::While>(m.get_match_root());
if (!while_node)
return false;
const auto& inputs = while_node->input_values();
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& mo
auto place = inputs[i];
if (place->get_names().size() != 0 && input_names.find(place->get_names().at(0)) != input_names.end()) {
auto input = converted_model->input(place->get_names().at(0));
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(input.get_node_shared_ptr());
auto param = ov::as_type_ptr<ov::op::v0::Parameter>(input.get_node_shared_ptr());
FRONT_END_GENERAL_CHECK(param, "Input is not a Parameter.");
update_parameter_info(param, place, converted_model);
} else {
Expand All @@ -205,7 +205,7 @@ std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& mo
update_parameter_info(parameters[idx], fplace, converted_model);
} else {
auto input = converted_model->input(fplace->get_names().at(0));
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(input.get_node_shared_ptr());
auto param = ov::as_type_ptr<ov::op::v0::Parameter>(input.get_node_shared_ptr());
FRONT_END_GENERAL_CHECK(param, "Input is not a Parameter.");
update_parameter_info(param, fplace, converted_model);
}
Expand Down
3 changes: 3 additions & 0 deletions src/frontends/pytorch/src/helper_ops/internal_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class InternalOpDecoder : public DummyDecoder {
};

class InternalOperation : public PtFrameworkNode {
public:
OPENVINO_OP("InternalOperation", "util", PtFrameworkNode);

protected:
InternalOperation(const std::string& op_type,
const OutputVector& inputs,
Expand Down
4 changes: 2 additions & 2 deletions src/frontends/pytorch/src/helper_ops/packed_sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace pytorch {

class PackPadded : public InternalOperation {
public:
OPENVINO_OP("PackPadded", "util", ov::op::util::FrameworkNode);
OPENVINO_OP("PackPadded", "util", InternalOperation);
PackPadded(const Output<Node>& input, const Output<Node>& lengths)
: InternalOperation("prim::PackPadded", {input, lengths}, 2, "This is PackedSequence pack operation.") {
validate_and_infer_types();
Expand All @@ -27,7 +27,7 @@ class PackPadded : public InternalOperation {

class PadPacked : public InternalOperation {
public:
OPENVINO_OP("PadPacked", "util", ov::op::util::FrameworkNode);
OPENVINO_OP("PadPacked", "util", InternalOperation);
PadPacked(const Output<Node>& input, const Output<Node>& lengths)
: InternalOperation("prim::PadPacked", {input, lengths}, 2, "This is PackedSequence unpack operation.") {
validate_and_infer_types();
Expand Down
6 changes: 3 additions & 3 deletions src/frontends/pytorch/src/node_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Output<Node> NodeContext::get_input_from_visible_context(size_t index) const {
FRONT_END_GENERAL_CHECK(index < get_input_size(), "Index ", index, " is lower then number of inputs.");
auto input_tensor = get_input(static_cast<int>(index));
auto input_node = input_tensor.get_node_shared_ptr();
if (std::dynamic_pointer_cast<v0::Parameter>(input_node)) {
if (ov::as_type_ptr<v0::Parameter>(input_node)) {
// We need to look into external context for inputs that would be feed into this parameter
size_t tensor_idx = m_translate_session->decode_tensor_name(input_node->output(0));
if (m_ext_tensor_map.count(tensor_idx)) {
Expand Down Expand Up @@ -298,7 +298,7 @@ template <>
std::string NodeContext::const_input<std::string>(size_t index) const {
FRONT_END_GENERAL_CHECK(!input_is_none(index), "Input with index: ", index, " is none.");
auto input_node = get_input_from_visible_context(index).get_node_shared_ptr();
auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_node);
auto input = ov::as_type_ptr<PtFrameworkNode>(input_node);
FRONT_END_GENERAL_CHECK(input,
"Input node with index ",
index,
Expand Down Expand Up @@ -327,7 +327,7 @@ Any NodeContext::get_values_from_const_input(int index) const {
if (input_is_none(index))
return {};
auto input_val = get_input_from_visible_context(index);
if (auto input = std::dynamic_pointer_cast<PtFrameworkNode>(input_val.get_node_shared_ptr())) {
if (auto input = ov::as_type_ptr<PtFrameworkNode>(input_val.get_node_shared_ptr())) {
const auto& attrs = input->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
return {};
Expand Down
Loading

0 comments on commit 42cb92e

Please sign in to comment.