Skip to content

Commit

Permalink
[tflchef] Fix tflchef to reflect extype and custom_code
Browse files Browse the repository at this point in the history
This introduces extype and custom_code for custom operators.

ONE-DCO-1.0-Signed-off-by: SeungHui Lee <[email protected]>
  • Loading branch information
Seunghui98 committed Oct 30, 2023
1 parent 0b3f914 commit f39a238
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 28 deletions.
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/AddV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "AddV2.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
AddV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "AddV2");
check_custom_op_value(operation, "AddV2");

/**
* REGISTER_OP("AddV2")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/All.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "All.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
AllChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "All");
check_custom_op_value(operation, "All");

/**
* REGISTER_OP("All")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "BatchMatMulV2.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
BatchMatMulV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "BatchMatMulV2");
check_custom_op_value(operation, "BatchMatMulV2");

/**
* REGISTER_OP("BatchMatMulV2")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "BroadcastTo.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
BroadcastToChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "BroadcastTo");
check_custom_op_value(operation, "BroadcastTo");

/**
* REGISTER_OP("BroadcastTo")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/Erf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "Erf.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
ErfChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "Erf");
check_custom_op_value(operation, "Erf");

/**
* REGISTER_OP("Erf")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "MatMul.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
MatMulChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MatMul");
check_custom_op_value(operation, "MatMul");

/**
* REGISTER_OP("MatMul")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "MatrixBandPart.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
MatrixBandPartChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MatrixBandPart");
check_custom_op_value(operation, "MatrixBandPart");

/**
* REGISTER_OP("MatrixBandPart")
Expand Down
4 changes: 2 additions & 2 deletions compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include "MaxPoolWithArgmax.h"
#include "OpUtils.h"

#include <flatbuffers/flexbuffers.h>

Expand All @@ -28,8 +29,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>
MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MaxPoolWithArgmax");
check_custom_op_value(operation, "MaxPoolWithArgmax");

/**
* REGISTER_OP("MaxPoolWithArgmax")
Expand Down
31 changes: 19 additions & 12 deletions compiler/tflchef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,11 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)

for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if ((operation.has_extype() && operation.extype() == "Custom") ||
(operation.has_type() && operation.type() == "Custom"))
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -157,10 +158,11 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if ((operation.has_extype() && operation.extype() == "Custom") ||
(operation.has_type() && operation.type() == "Custom"))
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -177,9 +179,9 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
std::set<std::string> customcode_set;
for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if ((operation.has_extype() && operation.extype() == "Custom") ||
(operation.has_type() && operation.type() == "Custom"))
customcode_set.insert(operation.custom_code());
}

// Add ops used in Graphs(subgraphs)
Expand All @@ -188,9 +190,9 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if ((operation.has_extype() && operation.extype() == "Custom") ||
(operation.has_type() && operation.type() == "Custom"))
customcode_set.insert(operation.custom_code());
}
}

Expand Down Expand Up @@ -619,7 +621,11 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
{
assert(operation.has_type());

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
std::string op_type = operation.type();
if (operation.has_custom_code())
op_type = operation.custom_code();

auto op_chef = op_chef_registry().lookup(op_type).create(&operation);

// Create 'inputs'
std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
Expand Down Expand Up @@ -650,7 +656,8 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
// custom operator
else
{
auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type());
auto op_it =
std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code());
assert(op_it != custom_code_vec.end());
opcode_index = builtin_code_map.size();
opcode_index += std::distance(custom_code_vec.begin(), op_it);
Expand Down
29 changes: 29 additions & 0 deletions compiler/tflchef/core/src/OpUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "OpUtils.h"

#include <stdexcept>

void check_custom_op_value(const tflchef::Operation operation, std::string op_type)
{
if ((operation.has_extype() && operation.extype() == "Custom") ||
(operation.has_type() && operation.type() == "Custom"))
{
assert(operation.has_custom_code());
assert(operation.custom_code() == op_type);
}
}
29 changes: 29 additions & 0 deletions compiler/tflchef/core/src/OpUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* @file OpUtils.h
* @brief This header declares various op_utils functions
*/
#ifndef __OPUTILS_H__
#define __OPUTILS_H__

#include <tflchef.pb.h>
#include <mio/tflite/schema_generated.h>

void check_custom_op_value(const tflchef::Operation operation, std::string op_type);

#endif // __OPUTILS_H__
2 changes: 2 additions & 0 deletions compiler/tflchef/tests/custom_erf/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ operation {
type: "Erf"
input: "ifm"
output: "ofm"
custom_code: "Erf"
extype: "Custom"
}
input: "ifm"
output: "ofm"

0 comments on commit f39a238

Please sign in to comment.