From 3fae109abc6d3bf9cf8b5e783bcb619444e1424c Mon Sep 17 00:00:00 2001 From: SeungHui Lee Date: Wed, 25 Oct 2023 04:32:03 -0400 Subject: [PATCH] [tflchef] Use extype and custom_code type for custom operators This commit includes replacing op type with custom_code if op_type is Custom or extype is Custom. ONE-DCO-1.0-Signed-off-by: SeungHui Lee --- compiler/tflchef/core/src/ModelChef.cpp | 36 ++++++++++++------- compiler/tflchef/tests/custom_erf/test.recipe | 2 ++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/compiler/tflchef/core/src/ModelChef.cpp b/compiler/tflchef/core/src/ModelChef.cpp index 3afcd232d70..0e8cdbe40ae 100644 --- a/compiler/tflchef/core/src/ModelChef.cpp +++ b/compiler/tflchef/core/src/ModelChef.cpp @@ -141,10 +141,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe) for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) + if ((operation.has_extype() && operation.extype() == "Custom") || operation.type() == "Custom") continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -157,10 +157,11 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe) const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) + if ((operation.has_extype() && operation.extype() == "Custom") || + operation.type() == "Custom") continue; + auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); // Various operation version is unified as the highest version among them if (builtin_map.find(op_chef->code()) == builtin_map.end() || builtin_map[op_chef->code()] < operation.version()) @@ -177,9 +178,11 @@ std::set gather_customcode_set(const ::tflchef::ModelRecipe &model_ std::set customcode_set; for (const auto &operation : model_recipe.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if ((operation.has_extype() && operation.extype() == "Custom") || operation.type() == "Custom") + { + assert(operation.has_custom_code()); + customcode_set.insert(operation.custom_code()); + } } // Add ops used in Graphs(subgraphs) @@ -188,9 +191,12 @@ std::set gather_customcode_set(const ::tflchef::ModelRecipe &model_ const auto &graph = model_recipe.graph(g); for (const auto &operation : graph.operation()) { - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); - if (op_chef->code() == tflite::BuiltinOperator_CUSTOM) - customcode_set.insert(operation.type()); + if ((operation.has_extype() && operation.extype() == "Custom") || + operation.type() == "Custom") + { + assert(operation.has_custom_code()); + customcode_set.insert(operation.custom_code()); + } } } @@ -619,7 +625,11 @@ template std::map cook_graph(const T &graph, { assert(operation.has_type()); - auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation); + std::string op_type = operation.type(); + if (operation.has_custom_code()) + op_type = operation.custom_code(); + + auto op_chef = op_chef_registry().lookup(op_type).create(&operation); // Create 'inputs' std::vector input_vec = as_dataset(operation.input()).map(lookup).vectorize(); @@ -650,7 +660,9 @@ template std::map cook_graph(const T &graph, // custom operator else { - auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type()); + assert(operation.has_custom_code()); + auto op_it = + std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code()); assert(op_it != custom_code_vec.end()); opcode_index = builtin_code_map.size(); opcode_index += std::distance(custom_code_vec.begin(), op_it); diff --git a/compiler/tflchef/tests/custom_erf/test.recipe b/compiler/tflchef/tests/custom_erf/test.recipe index ab093a30e30..417455bd058 100644 --- a/compiler/tflchef/tests/custom_erf/test.recipe +++ b/compiler/tflchef/tests/custom_erf/test.recipe @@ -12,6 +12,8 @@ operation { type: "Erf" input: "ifm" output: "ofm" + custom_code: "Erf" + extype: "Custom" } input: "ifm" output: "ofm"