From f39a238487e7f96a12556a66a3530d5c168c9d95 Mon Sep 17 00:00:00 2001 From: SeungHui Lee Date: Wed, 25 Oct 2023 04:32:03 -0400 Subject: [PATCH] [tflchef] Fix tflchef to reflect extype and custom_code This introduces extype and custom_code for custom operators. ONE-DCO-1.0-Signed-off-by: SeungHui Lee --- compiler/tflchef/core/src/CustomOp/AddV2.cpp | 4 +-- compiler/tflchef/core/src/CustomOp/All.cpp | 4 +-- .../core/src/CustomOp/BatchMatMulV2.cpp | 4 +-- .../tflchef/core/src/CustomOp/BroadcastTo.cpp | 4 +-- compiler/tflchef/core/src/CustomOp/Erf.cpp | 4 +-- compiler/tflchef/core/src/CustomOp/MatMul.cpp | 4 +-- .../core/src/CustomOp/MatrixBandPart.cpp | 4 +-- .../core/src/CustomOp/MaxPoolWithArgmax.cpp | 4 +-- compiler/tflchef/core/src/ModelChef.cpp | 31 ++++++++++++------- compiler/tflchef/core/src/OpUtils.cpp | 29 +++++++++++++++++ compiler/tflchef/core/src/OpUtils.h | 29 +++++++++++++++++ compiler/tflchef/tests/custom_erf/test.recipe | 2 ++ 12 files changed, 95 insertions(+), 28 deletions(-) create mode 100644 compiler/tflchef/core/src/OpUtils.cpp create mode 100644 compiler/tflchef/core/src/OpUtils.h diff --git a/compiler/tflchef/core/src/CustomOp/AddV2.cpp b/compiler/tflchef/core/src/CustomOp/AddV2.cpp index 557c20bce11..dd732c31ee2 100644 --- a/compiler/tflchef/core/src/CustomOp/AddV2.cpp +++ b/compiler/tflchef/core/src/CustomOp/AddV2.cpp @@ -16,6 +16,7 @@ */ #include "AddV2.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> AddV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "AddV2"); + check_custom_op_value(operation, "AddV2"); /** * REGISTER_OP("AddV2") diff --git a/compiler/tflchef/core/src/CustomOp/All.cpp b/compiler/tflchef/core/src/CustomOp/All.cpp index bbef5ecaa37..34a838bd437 100644 --- a/compiler/tflchef/core/src/CustomOp/All.cpp +++ b/compiler/tflchef/core/src/CustomOp/All.cpp @@ -16,6 +16,7 @@ */ #include "All.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> AllChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "All"); + check_custom_op_value(operation, "All"); /** * REGISTER_OP("All") diff --git a/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp b/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp index 6d2c5b13b61..ff359b063be 100644 --- a/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp +++ b/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp @@ -16,6 +16,7 @@ */ #include "BatchMatMulV2.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> BatchMatMulV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "BatchMatMulV2"); + check_custom_op_value(operation, "BatchMatMulV2"); /** * REGISTER_OP("BatchMatMulV2") diff --git a/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp b/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp index dd458b376c7..4485f7a57c0 100644 --- a/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp +++ b/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp @@ -16,6 +16,7 @@ */ #include "BroadcastTo.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> BroadcastToChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "BroadcastTo"); + check_custom_op_value(operation, "BroadcastTo"); /** * REGISTER_OP("BroadcastTo") diff --git a/compiler/tflchef/core/src/CustomOp/Erf.cpp b/compiler/tflchef/core/src/CustomOp/Erf.cpp index f611b68e1dc..cc3e91ecc17 100644 --- a/compiler/tflchef/core/src/CustomOp/Erf.cpp +++ b/compiler/tflchef/core/src/CustomOp/Erf.cpp @@ -16,6 +16,7 @@ */ #include "Erf.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> ErfChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "Erf"); + check_custom_op_value(operation, "Erf"); /** * REGISTER_OP("Erf") diff --git a/compiler/tflchef/core/src/CustomOp/MatMul.cpp b/compiler/tflchef/core/src/CustomOp/MatMul.cpp index e7c707d3722..8b1a1b28a21 100644 --- a/compiler/tflchef/core/src/CustomOp/MatMul.cpp +++ b/compiler/tflchef/core/src/CustomOp/MatMul.cpp @@ -16,6 +16,7 @@ */ #include "MatMul.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> MatMulChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "MatMul"); + check_custom_op_value(operation, "MatMul"); /** * REGISTER_OP("MatMul") diff --git a/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp b/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp index b2500322789..fb329d45dec 100644 --- a/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp +++ b/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp @@ -16,6 +16,7 @@ */ #include "MatrixBandPart.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> MatrixBandPartChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "MatrixBandPart"); + check_custom_op_value(operation, "MatrixBandPart"); /** * REGISTER_OP("MatrixBandPart") diff --git a/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp b/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp index 290d3c2cace..8fcbf11cb9d 100644 --- a/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp +++ b/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp @@ -16,6 +16,7 @@ */ #include "MaxPoolWithArgmax.h" +#include "OpUtils.h" #include @@ -28,8 +29,7 @@ flatbuffers::Offset> MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - - assert(operation.type() == "MaxPoolWithArgmax"); + check_custom_op_value(operation, "MaxPoolWithArgmax"); /** * REGISTER_OP("MaxPoolWithArgmax") diff --git a/compiler/tflchef/core/src/ModelChef.cpp b/compiler/tflchef/core/src/ModelChef.cpp index 3afcd232d70..5ac4222bb54 100644 --- a/compiler/tflchef/core/src/ModelChef.cpp +++ b/compiler/tflchef/core/src/ModelChef.cpp @@ -141,10 +141,11 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe) for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) + if ((operation.has_extype() && operation.extype() == "Custom") || + (operation.has_type() && operation.type() == "Custom")) continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -157,10 +158,11 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe) const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) + if ((operation.has_extype() && operation.extype() == "Custom") || + (operation.has_type() && operation.type() == "Custom")) continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -177,9 +179,9 @@ std::set gather_customcode_set(const ::tflchef::ModelRecipe &model_ std::set customcode_set; for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if ((operation.has_extype() && operation.extype() == "Custom") || + (operation.has_type() && operation.type() == "Custom")) + customcode_set.insert(operation.custom_code()); } // Add ops used in Graphs(subgraphs) @@ -188,9 +190,9 @@ std::set gather_customcode_set(const ::tflchef::ModelRecipe &model_ const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if ((operation.has_extype() && operation.extype() == "Custom") || + (operation.has_type() && operation.type() == "Custom")) + customcode_set.insert(operation.custom_code()); } } @@ -619,7 +621,11 @@ template std::map cook_graph(const T &graph, { assert(operation.has_type()); - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); + std::string op_type = operation.type(); + if (operation.has_custom_code()) + op_type = operation.custom_code(); + + auto op_chef = op_chef_registry().lookup(op_type).create(&operation); // Create 'inputs' std::vector input_vec = as_dataset(operation.input()).map(lookup).vectorize(); @@ -650,7 +656,8 @@ template std::map cook_graph(const T &graph, // custom operator else { - auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type()); + auto op_it = + std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code()); assert(op_it != custom_code_vec.end()); opcode_index = builtin_code_map.size(); opcode_index += std::distance(custom_code_vec.begin(), op_it); diff --git a/compiler/tflchef/core/src/OpUtils.cpp b/compiler/tflchef/core/src/OpUtils.cpp new file mode 100644 index 00000000000..008f37fb1f0 --- /dev/null +++ b/compiler/tflchef/core/src/OpUtils.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OpUtils.h" + +#include + +void check_custom_op_value(const tflchef::Operation operation, std::string op_type) +{ + if ((operation.has_extype() && operation.extype() == "Custom") || + (operation.has_type() && operation.type() == "Custom")) + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == op_type); + } +} diff --git a/compiler/tflchef/core/src/OpUtils.h b/compiler/tflchef/core/src/OpUtils.h new file mode 100644 index 00000000000..86532a684ac --- /dev/null +++ b/compiler/tflchef/core/src/OpUtils.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @file OpUtils.h + * @brief This header declares various op_utils functions + */ +#ifndef __OPUTILS_H__ +#define __OPUTILS_H__ + +#include +#include + +void check_custom_op_value(const tflchef::Operation operation, std::string op_type); + +#endif // __OPUTILS_H__ diff --git a/compiler/tflchef/tests/custom_erf/test.recipe b/compiler/tflchef/tests/custom_erf/test.recipe index ab093a30e30..417455bd058 100644 --- a/compiler/tflchef/tests/custom_erf/test.recipe +++ b/compiler/tflchef/tests/custom_erf/test.recipe @@ -12,6 +12,8 @@ operation { type: "Erf" input: "ifm" output: "ofm" + custom_code: "Erf" + extype: "Custom" } input: "ifm" output: "ofm"