diff --git a/compiler/circlechef/core/src/ModelChef.cpp b/compiler/circlechef/core/src/ModelChef.cpp index 6c5206dfc3d..b491a4fe893 100644 --- a/compiler/circlechef/core/src/ModelChef.cpp +++ b/compiler/circlechef/core/src/ModelChef.cpp @@ -135,10 +135,11 @@ gather_builtincode_map(const ::circlechef::ModelRecipe &model_recipe) for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == circle::BuiltinOperator_CUSTOM) + if (operation.has_extype() && operation.extype() == "Custom") continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); + // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -151,10 +152,10 @@ gather_builtincode_map(const ::circlechef::ModelRecipe &model_recipe) const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == circle::BuiltinOperator_CUSTOM) + if (operation.has_extype() && operation.extype() == "Custom") continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -171,9 +172,8 @@ std::set<std::string> gather_customcode_set(const ::circlechef::ModelRecipe &mod std::set<std::string> customcode_set; for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == circle::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if (operation.has_extype() && operation.extype() == "Custom") + customcode_set.insert(operation.custom_code()); 
} // Add ops used in Graphs(subgraphs) @@ -182,9 +182,8 @@ std::set<std::string> gather_customcode_set(const ::circlechef::ModelRecipe &mod const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == circle::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if (operation.has_extype() && operation.extype() == "Custom") + customcode_set.insert(operation.custom_code()); } } @@ -418,7 +417,11 @@ template <typename T> void cook_graph(const T &graph, CookParams &cp) { assert(operation.has_type()); - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); + std::string op_type = operation.type(); + if (operation.has_custom_code()) + op_type = operation.custom_code(); + + auto op_chef = op_chef_registry().lookup(op_type).create(&operation); // Create 'inputs' std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize(); diff --git a/compiler/circlechef/proto/circlechef.proto b/compiler/circlechef/proto/circlechef.proto index d5e08576f0b..c2db55a9d84 100644 --- a/compiler/circlechef/proto/circlechef.proto +++ b/compiler/circlechef/proto/circlechef.proto @@ -91,7 +91,9 @@ message Operation { repeated string input = 2; repeated string output = 3; optional int32 version = 4 [default = 1]; + optional string custom_code = 5; + optional string extype = 99; optional BatchMatMulOptions batch_matmul_options = 100; optional InstanceNormOptions instance_norm_options = 101; optional BCQFullyConnectedOptions bcq_fully_connected_options = 102; diff --git a/compiler/tflchef/core/src/CustomOp/AddV2.cpp b/compiler/tflchef/core/src/CustomOp/AddV2.cpp index 557c20bce11..151a01025bc 100644 --- a/compiler/tflchef/core/src/CustomOp/AddV2.cpp +++ b/compiler/tflchef/core/src/CustomOp/AddV2.cpp @@ -29,7 +29,11 @@ AddV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - 
assert(operation.type() == "AddV2"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "AddV2"); + } /** * REGISTER_OP("AddV2") diff --git a/compiler/tflchef/core/src/CustomOp/All.cpp b/compiler/tflchef/core/src/CustomOp/All.cpp index bbef5ecaa37..7e33c6ffb1a 100644 --- a/compiler/tflchef/core/src/CustomOp/All.cpp +++ b/compiler/tflchef/core/src/CustomOp/All.cpp @@ -29,7 +29,11 @@ AllChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "All"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "All"); + } /** * REGISTER_OP("All") diff --git a/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp b/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp index 6d2c5b13b61..24833c5f48b 100644 --- a/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp +++ b/compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp @@ -29,7 +29,11 @@ BatchMatMulV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "BatchMatMulV2"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "BatchMatMulV2"); + } /** * REGISTER_OP("BatchMatMulV2") diff --git a/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp b/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp index dd458b376c7..b39c7d666d2 100644 --- a/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp +++ b/compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp @@ -29,7 +29,11 @@ BroadcastToChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "BroadcastTo"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + 
assert(operation.custom_code() == "BroadcastTo"); + } /** * REGISTER_OP("BroadcastTo") diff --git a/compiler/tflchef/core/src/CustomOp/Erf.cpp b/compiler/tflchef/core/src/CustomOp/Erf.cpp index f611b68e1dc..6393a5bd3b0 100644 --- a/compiler/tflchef/core/src/CustomOp/Erf.cpp +++ b/compiler/tflchef/core/src/CustomOp/Erf.cpp @@ -29,7 +29,11 @@ ErfChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "Erf"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "Erf"); + } /** * REGISTER_OP("Erf") diff --git a/compiler/tflchef/core/src/CustomOp/MatMul.cpp b/compiler/tflchef/core/src/CustomOp/MatMul.cpp index e7c707d3722..5f5f2c048db 100644 --- a/compiler/tflchef/core/src/CustomOp/MatMul.cpp +++ b/compiler/tflchef/core/src/CustomOp/MatMul.cpp @@ -29,7 +29,11 @@ MatMulChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "MatMul"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "MatMul"); + } /** * REGISTER_OP("MatMul") diff --git a/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp b/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp index b2500322789..46ea5c0d574 100644 --- a/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp +++ b/compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp @@ -29,7 +29,11 @@ MatrixBandPartChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "MatrixBandPart"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "MatrixBandPart"); + } /** * REGISTER_OP("MatrixBandPart") diff --git a/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp 
b/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp index 290d3c2cace..a2c0ed2e79e 100644 --- a/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp +++ b/compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp @@ -29,7 +29,11 @@ MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const { auto &operation = (*_operation); - assert(operation.type() == "MaxPoolWithArgmax"); + if (operation.has_extype() && operation.extype() == "Custom") + { + assert(operation.has_custom_code()); + assert(operation.custom_code() == "MaxPoolWithArgmax"); + } /** * REGISTER_OP("MaxPoolWithArgmax") diff --git a/compiler/tflchef/core/src/ModelChef.cpp b/compiler/tflchef/core/src/ModelChef.cpp index 3afcd232d70..128f818ed4e 100644 --- a/compiler/tflchef/core/src/ModelChef.cpp +++ b/compiler/tflchef/core/src/ModelChef.cpp @@ -141,10 +141,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe) for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) + if (operation.has_extype() && operation.extype() == "Custom") continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -157,10 +157,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe) const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) + if (operation.has_extype() && operation.extype() == "Custom") continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest 
version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -177,9 +177,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_ std::set<std::string> customcode_set; for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if (operation.has_extype() && operation.extype() == "Custom") + customcode_set.insert(operation.custom_code()); } // Add ops used in Graphs(subgraphs) @@ -188,9 +187,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_ const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if (operation.has_extype() && operation.extype() == "Custom") + customcode_set.insert(operation.custom_code()); } } @@ -619,7 +617,11 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, { assert(operation.has_type()); - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); + std::string op_type = operation.type(); + if (operation.has_custom_code()) + op_type = operation.custom_code(); + + auto op_chef = op_chef_registry().lookup(op_type).create(&operation); // Create 'inputs' std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize(); @@ -650,7 +652,8 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, // custom operator else { - auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type()); + auto op_it = + std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code()); assert(op_it != custom_code_vec.end()); opcode_index = builtin_code_map.size(); opcode_index += 
std::distance(custom_code_vec.begin(), op_it); diff --git a/compiler/tflchef/proto/tflchef.proto b/compiler/tflchef/proto/tflchef.proto index 98ae2b23f94..2111415b23a 100644 --- a/compiler/tflchef/proto/tflchef.proto +++ b/compiler/tflchef/proto/tflchef.proto @@ -556,7 +556,9 @@ message Operation { repeated string input = 2; repeated string output = 3; optional int32 version = 4 [default = 1]; + optional string custom_code = 5; + optional string extype = 99; optional Conv2DOptions conv2d_options = 100; optional Pool2DOptions averagepool2d_options = 101; optional ConcatenationOptions concatenation_options = 102; diff --git a/compiler/tflchef/tests/custom_erf/test.recipe b/compiler/tflchef/tests/custom_erf/test.recipe index ab093a30e30..417455bd058 100644 --- a/compiler/tflchef/tests/custom_erf/test.recipe +++ b/compiler/tflchef/tests/custom_erf/test.recipe @@ -12,6 +12,8 @@ operation { type: "Erf" input: "ifm" output: "ofm" + custom_code: "Erf" + extype: "Custom" } input: "ifm" output: "ofm" diff --git a/res/TensorFlowLiteRecipes/All_000/test.recipe b/res/TensorFlowLiteRecipes/All_000/test.recipe index 6e0da1a473d..5a44082149f 100644 --- a/res/TensorFlowLiteRecipes/All_000/test.recipe +++ b/res/TensorFlowLiteRecipes/All_000/test.recipe @@ -25,6 +25,8 @@ operation { input: "ifm" input: "All/reduction_indices" output: "ofm" + custom_code: "All" + extype: "Custom" } input: "ifm" output: "ofm" diff --git a/res/TensorFlowLiteRecipes/BatchMatMulV2_000/test.recipe b/res/TensorFlowLiteRecipes/BatchMatMulV2_000/test.recipe index c6163af924c..c04a6a3b43b 100644 --- a/res/TensorFlowLiteRecipes/BatchMatMulV2_000/test.recipe +++ b/res/TensorFlowLiteRecipes/BatchMatMulV2_000/test.recipe @@ -18,6 +18,8 @@ operation { input: "ifm1" input: "ifm2" output: "ofm" + custom_code: "BatchMatMulV2" + extype: "Custom" } input: "ifm1" input: "ifm2" diff --git a/res/TensorFlowLiteRecipes/BatchMatMulV2_001/test.recipe b/res/TensorFlowLiteRecipes/BatchMatMulV2_001/test.recipe index 
9350ca8dc73..1b17396620b 100644 --- a/res/TensorFlowLiteRecipes/BatchMatMulV2_001/test.recipe +++ b/res/TensorFlowLiteRecipes/BatchMatMulV2_001/test.recipe @@ -21,6 +21,8 @@ operation { input: "ifm1" input: "ifm2" output: "ofm" + custom_code: "BatchMatMulV2" + extype: "Custom" } input: "ifm1" input: "ifm2" diff --git a/res/TensorFlowLiteRecipes/BroadcastTo_000/test.recipe b/res/TensorFlowLiteRecipes/BroadcastTo_000/test.recipe index 015e40bc433..0fb9da8bc55 100644 --- a/res/TensorFlowLiteRecipes/BroadcastTo_000/test.recipe +++ b/res/TensorFlowLiteRecipes/BroadcastTo_000/test.recipe @@ -19,6 +19,8 @@ operation { input: "bc_input" input: "bc_shape" output: "bc_ofm" + custom_code: "BroadcastTo" + extype: "Custom" } input: "bc_input" output: "bc_ofm" diff --git a/res/TensorFlowLiteRecipes/MatMul_000/test.recipe b/res/TensorFlowLiteRecipes/MatMul_000/test.recipe index 6e7853aa098..9502cb3b75f 100644 --- a/res/TensorFlowLiteRecipes/MatMul_000/test.recipe +++ b/res/TensorFlowLiteRecipes/MatMul_000/test.recipe @@ -22,6 +22,8 @@ operation { transpose_a: true transpose_b: false } + custom_code: "MatMul" + extype: "Custom" } input: "ifm1" input: "ifm2" diff --git a/res/TensorFlowLiteRecipes/MatrixBandPart_000/test.recipe b/res/TensorFlowLiteRecipes/MatrixBandPart_000/test.recipe index 702aaf8ecc8..7ffdaac9077 100644 --- a/res/TensorFlowLiteRecipes/MatrixBandPart_000/test.recipe +++ b/res/TensorFlowLiteRecipes/MatrixBandPart_000/test.recipe @@ -32,6 +32,8 @@ operation { input: "MatrixBandPart/num_lower" input: "MatrixBandPart/num_upper" output: "ofm" + custom_code: "MatrixBandPart" + extype: "Custom" } input: "ifm" output: "ofm" diff --git a/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_000/test.recipe b/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_000/test.recipe index 9218c20109e..5dba4979548 100644 --- a/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_000/test.recipe +++ b/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_000/test.recipe @@ -27,6 +27,8 @@ operation { output_type: INT64 
include_batch_in_index: false } + custom_code: "MaxPoolWithArgmax" + extype: "Custom" } input: "ifm" output: "ofm" diff --git a/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_001/test.recipe b/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_001/test.recipe index 9c15a7c63b8..cdbd2456076 100644 --- a/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_001/test.recipe +++ b/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_001/test.recipe @@ -27,6 +27,8 @@ operation { output_type: INT32 include_batch_in_index: false } + custom_code: "MaxPoolWithArgmax" + extype: "Custom" } input: "ifm" output: "ofm" diff --git a/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_002/test.recipe b/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_002/test.recipe index 702e0163482..f0d7d4ec1f3 100644 --- a/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_002/test.recipe +++ b/res/TensorFlowLiteRecipes/MaxPoolWithArgmax_002/test.recipe @@ -27,6 +27,8 @@ operation { output_type: INT64 include_batch_in_index: false } + custom_code: "MaxPoolWithArgmax" + extype: "Custom" } input: "ifm" output: "ofm" diff --git a/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_000/test.recipe b/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_000/test.recipe index 5069aac0996..09c1106729a 100644 --- a/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_000/test.recipe +++ b/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_000/test.recipe @@ -19,6 +19,8 @@ operation { input: "bc_input" input: "bc_shape" output: "bc_ofm" + custom_code: "BroadcastTo" + extype: "Custom" } operand { name: "reshape_data" @@ -57,6 +59,8 @@ operation { input: "bc_ofm" input: "reshape_ofm" output: "ofm" + custom_code: "AddV2" + extype: "Custom" } input: "bc_input" input: "reshape_data" diff --git a/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_001/test.recipe b/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_001/test.recipe index ca0ad8e03bd..0dd9f8b8525 100644 --- a/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_001/test.recipe +++ 
b/res/TensorFlowLiteRecipes/Net_BroadcastTo_AddV2_001/test.recipe @@ -19,6 +19,8 @@ operation { input: "bc_input" input: "bc_shape" output: "bc_ofm" + custom_code: "BroadcastTo" + extype: "Custom" } operand { name: "reshape_data" @@ -53,10 +55,12 @@ operand { shape { dim: 1 dim: 2 dim: 3 } } operation { - type: "AddV2" + type: "BroadcastTo" input: "bc_ofm" input: "reshape_ofm" output: "ofm" + custom_code: "AddV2" + extype: "Custom" } input: "bc_input" input: "reshape_data" diff --git a/res/TensorFlowLiteRecipes/Net_FC_Gelu_FC_000/test.recipe b/res/TensorFlowLiteRecipes/Net_FC_Gelu_FC_000/test.recipe index 9e53d18f9e1..703f450dbf5 100644 --- a/res/TensorFlowLiteRecipes/Net_FC_Gelu_FC_000/test.recipe +++ b/res/TensorFlowLiteRecipes/Net_FC_Gelu_FC_000/test.recipe @@ -104,6 +104,8 @@ operation { type: "Erf" input: "fc2" output: "erf" + custom_code: "Erf" + extype: "Custom" } operation { type: "Add" diff --git a/res/TensorFlowLiteRecipes/Net_Gather_SparseToDense_AddV2_000/test.recipe b/res/TensorFlowLiteRecipes/Net_Gather_SparseToDense_AddV2_000/test.recipe index 804d293fc83..326836fd1ab 100644 --- a/res/TensorFlowLiteRecipes/Net_Gather_SparseToDense_AddV2_000/test.recipe +++ b/res/TensorFlowLiteRecipes/Net_Gather_SparseToDense_AddV2_000/test.recipe @@ -103,6 +103,8 @@ operation { input: "ofm_sparse" input: "add_v2_2" output: "ofm_add_v2" + custom_code: "AddV2" + extype: "Custom" } operation { type: "Cast" diff --git a/res/TensorFlowLiteRecipes/Net_Gelu_000/test.recipe b/res/TensorFlowLiteRecipes/Net_Gelu_000/test.recipe index ae7f823e8b9..6285bc1c81a 100644 --- a/res/TensorFlowLiteRecipes/Net_Gelu_000/test.recipe +++ b/res/TensorFlowLiteRecipes/Net_Gelu_000/test.recipe @@ -68,6 +68,8 @@ operation { type: "Erf" input: "mul_sqrt" output: "erf" + custom_code: "Erf" + extype: "Custom" } operation { type: "Add" diff --git a/res/TensorFlowLiteRecipes/Net_Gelu_001/test.recipe b/res/TensorFlowLiteRecipes/Net_Gelu_001/test.recipe index 76337293a3c..eddf6c62de7 100644 --- 
a/res/TensorFlowLiteRecipes/Net_Gelu_001/test.recipe +++ b/res/TensorFlowLiteRecipes/Net_Gelu_001/test.recipe @@ -68,6 +68,8 @@ operation { type: "Erf" input: "mul_sqrt" output: "erf" + custom_code: "Erf" + extype: "Custom" } operation { type: "Add"