diff --git a/src/frontends/tensorflow_common/src/op/select.cpp b/src/frontends/tensorflow_common/src/op/select.cpp index f19e01f5a021e6..35c7e893e542e1 100644 --- a/src/frontends/tensorflow_common/src/op/select.cpp +++ b/src/frontends/tensorflow_common/src/op/select.cpp @@ -13,6 +13,7 @@ #include "openvino/op/shape_of.hpp" #include "openvino/op/squeeze.hpp" #include "openvino/op/subtract.hpp" +#include "openvino/op/unsqueeze.hpp" using namespace std; using namespace ov; @@ -31,7 +32,19 @@ OutputVector translate_select_base_op(const NodeContext& node, set_node_name(node.get_name(), select); return {select}; } - +bool has_complex_inputs(Output& x, Output& y, element::Type& complex_part_type) { + auto complex_type_mark_x = as_type_ptr(x.get_node_shared_ptr()); + auto complex_type_mark_y = as_type_ptr(y.get_node_shared_ptr()); + if (complex_type_mark_x) { + x = complex_type_mark_x->input_value(0); + complex_part_type = complex_type_mark_x->get_complex_part_type(); + } + if (complex_type_mark_y) { + y = complex_type_mark_y->input_value(0); + complex_part_type = complex_type_mark_y->get_complex_part_type(); + } + return (complex_type_mark_x || complex_type_mark_y); +} OutputVector translate_select_v2_op(const NodeContext& node) { // according to the TensorFlow documentation. See in the code: // https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/lite/kernels/select.cc#L188-L211 @@ -40,10 +53,23 @@ OutputVector translate_select_v2_op(const NodeContext& node) { // is true or the value of 'y' if false. There are valid condition input sizes: // 1. Either the same shape (in which case the select is elementwise), or // 2. Broadcastable shapes between 'condition', 'x' and 'y'. - default_op_checks(node, 3, {"SelectV2", "SELECT_V2"}); - // no preparation for inputs are needed - // inputs are already NumPy broadcastable - return translate_select_base_op(node, node.get_input(0), node.get_input(1), node.get_input(2)); + default_op_checks(node, 3, {"SelectV2", "SELECT_V2"}, true); + auto condition = node.get_input(0); + auto x = node.get_input(1); + auto y = node.get_input(2); + + element::Type complex_part_type; + auto is_complex = has_complex_inputs(x, y, complex_part_type); + + if (is_complex) { + auto const_negative_one = make_shared(element::i32, Shape{1}, -1); + auto new_condition = make_shared(condition, const_negative_one); + auto result = translate_select_base_op(node, new_condition, x, y); + auto complex_result = make_shared(result[0].get_node_shared_ptr(), complex_part_type); + return {complex_result->output(0)}; + } else { + return translate_select_base_op(node, condition, x, y); + } } OutputVector translate_select_op(const NodeContext& node) { @@ -59,21 +85,9 @@ OutputVector translate_select_op(const NodeContext& node) { auto condition = node.get_input(0); auto x = node.get_input(1); auto y = node.get_input(2); - auto complex_type_mark_x = as_type_ptr(x.get_node_shared_ptr()); - auto complex_type_mark_y = as_type_ptr(y.get_node_shared_ptr()); - auto is_complex = (complex_type_mark_x || complex_type_mark_y); element::Type complex_part_type; - - if (complex_type_mark_x) { - x = complex_type_mark_x->input_value(0); - complex_part_type = complex_type_mark_x->get_complex_part_type(); - } - - if (complex_type_mark_y) { - y = complex_type_mark_y->input_value(0); - complex_part_type = complex_type_mark_y->get_complex_part_type(); - } + auto is_complex = has_complex_inputs(x, y, complex_part_type); // compute number of dimensions to unsqueeze the condition auto cond_rank = compute_subgraph_scalar_rank(condition, element::i32); @@ -85,14 +99,13 @@ OutputVector translate_select_op(const NodeContext& node) { auto new_subshape = make_shared(const_one, num_new_axes); auto cond_shape = make_shared(condition, element::i32); // use extra dimensions in the begin to avoid concatenation of empty tensors that is not supported by Concat - auto const_1 = make_shared(element::i32, Shape{1}, 1); - auto new_cond_shape = make_shared(OutputVector{const_1, cond_shape, new_subshape}, 0); + auto new_cond_shape = make_shared(OutputVector{const_one, cond_shape, new_subshape}, 0); // prepare the condition to have the same rank as operands `x` and `y` auto prep_cond = make_shared(condition, new_cond_shape, false)->output(0); // squeeze prep_cond by one extra dimension specially added - auto const_0 = make_shared(element::i32, Shape{1}, 0); - prep_cond = make_shared(prep_cond, const_0); + auto const_zero = make_shared(element::i32, Shape{1}, 0); + prep_cond = make_shared(prep_cond, const_zero); auto result = translate_select_base_op(node, prep_cond, x, y); if (is_complex) { diff --git a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py index 058f2e21a4a60b..d199275bf34345 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_SelectV2.py @@ -51,3 +51,52 @@ def test_select_v2_basic(self, params, ie_device, precision, ir_version, temp_di self._test(*self.create_select_v2_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) + + +class TestComplexSelectV2(CommonTFLayerTest): + def _prepare_input(self, inputs_info): + rng = np.random.default_rng() + assert 'cond:0' in inputs_info, "Test error: inputs_info must contain `cond`" + assert 'x_real:0' in inputs_info, "Test error: inputs_info must contain `x_real`" + assert 'x_imag:0' in inputs_info, "Test error: inputs_info must contain `x_imag`" + assert 'y_real:0' in inputs_info, "Test error: inputs_info must contain `y_real`" + assert 'y_imag:0' in inputs_info, "Test error: inputs_info must contain `y_imag`" + cond_shape = inputs_info['cond:0'] + inputs_data = {} + inputs_data['cond:0'] = np.random.randint(0, 2, cond_shape).astype(bool) + for part in ['x_real:0', 'x_imag:0', 'y_real:0', 'y_imag:0']: + inputs_data[part] = 4 * rng.random(inputs_info[part]).astype(np.float32) - 2 + return inputs_data + + def create_complex_select_v2_net(self, cond_shape, x_shape, y_shape): + tf.compat.v1.reset_default_graph() + # Create the graph and model + with tf.compat.v1.Session() as sess: + cond = tf.compat.v1.placeholder(tf.bool, cond_shape, 'cond') + x_real = tf.compat.v1.placeholder(tf.float32, x_shape, 'x_real') + x_imag = tf.compat.v1.placeholder(tf.float32, x_shape, 'x_imag') + y_real = tf.compat.v1.placeholder(tf.float32, y_shape, 'y_real') + y_imag = tf.compat.v1.placeholder(tf.float32, y_shape, 'y_imag') + complex_x = tf.raw_ops.Complex(real=x_real, imag=x_imag) + complex_y = tf.raw_ops.Complex(real=y_real, imag=y_imag) + complex_select = tf.raw_ops.SelectV2(condition=cond, t=complex_x, e=complex_y) + tf.raw_ops.Real(input=complex_select) + tf.raw_ops.Imag(input=complex_select) + tf.compat.v1.global_variables_initializer() + tf_net = sess.graph_def + return tf_net, None + + test_data_basic = [ + dict(cond_shape=[3, 1], x_shape=[3, 1], y_shape=[3, 1]), + dict(cond_shape=[], x_shape=[2], y_shape=[3, 2]), + dict(cond_shape=[4], x_shape=[3, 2, 1], y_shape=[2, 4]), + ] + + @pytest.mark.parametrize("params", test_data_basic) + @pytest.mark.precommit + @pytest.mark.nightly + def test_complex_select_v2(self, params, ie_device, precision, ir_version, temp_dir, + use_legacy_frontend): + self._test(*self.create_complex_select_v2_net(**params), + ie_device, precision, ir_version, temp_dir=temp_dir, + use_legacy_frontend=use_legacy_frontend) \ No newline at end of file