Skip to content

Commit

Permalink
[tflchef] Use extype and custom_code type for custom operators
Browse files Browse the repository at this point in the history
This commit includes replacing op type with custom_code if op_type is Custom or extype is Custom.

ONE-DCO-1.0-Signed-off-by: SeungHui Lee <[email protected]>
  • Loading branch information
Seunghui98 committed Oct 31, 2023
1 parent 6eb9ec4 commit 3fae109
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
36 changes: 24 additions & 12 deletions compiler/tflchef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)

for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if ((operation.has_extype() && operation.extype() == "Custom") || operation.type() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -157,10 +157,11 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if ((operation.has_extype() && operation.extype() == "Custom") ||
operation.type() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -177,9 +178,11 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
std::set<std::string> customcode_set;
for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if ((operation.has_extype() && operation.extype() == "Custom") || operation.type() == "Custom")
{
assert(operation.has_custom_code());
customcode_set.insert(operation.custom_code());
}
}

// Add ops used in Graphs(subgraphs)
Expand All @@ -188,9 +191,12 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if ((operation.has_extype() && operation.extype() == "Custom") ||
operation.type() == "Custom")
{
assert(operation.has_custom_code());
customcode_set.insert(operation.custom_code());
}
}
}

Expand Down Expand Up @@ -619,7 +625,11 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
{
assert(operation.has_type());

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
std::string op_type = operation.type();
if (operation.has_custom_code())
op_type = operation.custom_code();

auto op_chef = op_chef_registry().lookup(op_type).create(&operation);

// Create 'inputs'
std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
Expand Down Expand Up @@ -650,7 +660,9 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
// custom operator
else
{
auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type());
assert(operation.has_custom_code());
auto op_it =
std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code());
assert(op_it != custom_code_vec.end());
opcode_index = builtin_code_map.size();
opcode_index += std::distance(custom_code_vec.begin(), op_it);
Expand Down
2 changes: 2 additions & 0 deletions compiler/tflchef/tests/custom_erf/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ operation {
type: "Erf"
input: "ifm"
output: "ofm"
custom_code: "Erf"
extype: "Custom"
}
input: "ifm"
output: "ofm"

0 comments on commit 3fae109

Please sign in to comment.