Skip to content

Commit

Permalink
[ADD] Shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
PiotrKrzem committed Nov 21, 2024
1 parent 900b0fa commit 31e925a
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 11 deletions.
22 changes: 22 additions & 0 deletions src/core/shape_inference/include/identity_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/identity.hpp"

namespace ov {
namespace op {
namespace v16 {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> shape_infer(const Identity* op, const std::vector<TShape>& input_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 1);

const auto& input_shape = input_shapes[0];

return {input_shape};
}
} // namespace v16
} // namespace op
} // namespace ov
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/nodes/identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ Identity::Identity(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
THROW_CPU_NODE_ERR(errorMessage);
}\
}

auto op = as_type_ptr<op::v16::Identity>(op);
auto identity_op = as_type_ptr<op::v16::Identity>(op);

if (is_type<op::v0::Constant>(op->get_input_node_ptr(0))) {
if (is_type<op::v0::Constant>(identity_op->get_input_node_ptr(0))) {
m_const_input = true;
constant = ConstantType::Const; // Node always produces the same output
} else {
Expand All @@ -57,7 +57,7 @@ void Identity::initSupportedPrimitiveDescriptors() {
auto out_prc = getOriginalOutputPrecisionAtPort(0);

if (shape_prc != out_prc) {
THROW_CPU_NODE_ERR("has to have the same dtype for input and output nodes.")
THROW_CPU_NODE_ERR("has to have the same dtype for input and output nodes.");
}

m_out_prc = out_prc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "single_layer_tests/classes/identity.hpp"
#include "identity.hpp"
#include "common_test_utils/node_builders/constant.hpp"

using namespace CPUTestUtils;
Expand Down Expand Up @@ -43,18 +43,18 @@ void IdentityLayerTestCPU::SetUp() {

std::tie(inFmts, outFmts, priority, selectedType) = cpu_params;

updateSelectedType("ref_any", m_output_precision, configuration);
updateSelectedType("ref_any", output_precision, configuration);

std::vector<InputShape> in_shapes;

if (!const_input) {
in_shapes.push_back({{}, {{m_output_shape}}});
in_shapes.push_back({{}, {{output_shape}}});
} else {
in_shapes.push_back({{m_output_shape}, {{m_output_shape}}});
in_shapes.push_back({{output_shape}, {{output_shape}}});
}
init_input_shapes(in_shapes);

const auto data = std::make_shared<ov::op::v0::Parameter>(m_output_precision, m_output_shape);
const auto data = std::make_shared<ov::op::v0::Parameter>(output_precision, output_shape);
data->set_friendly_name("data");

const auto op = std::make_shared<ov::op::v16::Identity>(data);
Expand All @@ -65,7 +65,9 @@ void IdentityLayerTestCPU::SetUp() {

void IdentityLayerTestCPU::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
inputs.clear();
auto tensor = ov::test::utils::create_and_fill_tensor(function->inputs()[0].get_element_type(), targetInputStaticShapes[0], InputGenerateData());
const auto& func_inputs = function->inputs();
const auto& func_input = func_inputs[0];
auto tensor = ov::test::utils::create_and_fill_tensor(func_input.get_element_type(), targetInputStaticShapes[0], utils::InputGenerateData());
inputs.insert({func_input.get_node_shared_ptr(), tensor});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ static const std::vector<ov::Shape> shapes = {

static const std::vector<ElementType> prc = {
ElementType::f32,
ElementType::f16.
ElementType::f16,
ElementType::bf16,
ElementType::i32,
ElementType::u16,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include "identity_shape_inference.hpp"
#include "utils.hpp"

using namespace ov;
using namespace ov::intel_cpu;

class Identityv14StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v16::Identity> {
protected:
};

TEST_F(Identityv14StaticShapeInferenceTest, Identity_default_ctor) {
op = make_op();
op->set_adjoint(false);

input_shapes = ShapeVector{{2, 2}};
auto output_shapes = shape_inference(op.get(), input_shapes);

EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes[0], StaticShape({2, 2}));
}

TEST_F(Identityv14StaticShapeInferenceTest, Identity_4_4_small_matrix) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
auto Identity = std::make_shared<op::v14::Identity>(data);

input_shapes = ShapeVector{{4, 4}};
auto output_shapes = shape_inference(Identity.get(), input_shapes);
ASSERT_EQ(output_shapes[0], StaticShape({4, 4}));
}

TEST_F(Identityv14StaticShapeInferenceTest, Identity_10_10_big_matrix) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
auto Identity = std::make_shared<op::v14::Identity>(data);

input_shapes = ShapeVector{{10, 10}};
auto output_shapes = shape_inference(Identity.get(), input_shapes);
ASSERT_EQ(output_shapes[0], StaticShape({10, 10}));
}

TEST_F(Identityv14StaticShapeInferenceTest, Identity_10_1_1_keep_batch_when_single_cell_matrix) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
auto Identity = std::make_shared<op::v14::Identity>(data);

input_shapes = ShapeVector{{10, 1, 1}};
auto output_shapes = shape_inference(Identity.get(), input_shapes);
ASSERT_EQ(output_shapes[0], StaticShape({10, 1, 1}));
}

TEST_F(Identityv14StaticShapeInferenceTest, Identity_10_9_9_keep_batch_big_matrix) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
auto Identity = std::make_shared<op::v14::Identity>(data);

input_shapes = ShapeVector{{10, 9, 9}};
auto output_shapes = shape_inference(Identity.get(), input_shapes);
ASSERT_EQ(output_shapes[0], StaticShape({10, 9, 9}));
}

TEST_F(Identityv14StaticShapeInferenceTest, Identity_10_5_3_2_2_complex_multi_dim_matrix) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
auto Identity = std::make_shared<op::v14::Identity>(data);

input_shapes = ShapeVector{{10, 5, 3, 2, 2}};
auto output_shapes = shape_inference(Identity.get(), input_shapes);
ASSERT_EQ(output_shapes[0], StaticShape({10, 5, 3, 2, 2}));
}

0 comments on commit 31e925a

Please sign in to comment.