Skip to content

Commit

Permalink
[tflchef] Fix tflchef to reflect extype and custom_code
Browse files Browse the repository at this point in the history
This introduces extype and custom_code for custom operators.

ONE-DCO-1.0-Signed-off-by: SeungHui Lee <[email protected]>
  • Loading branch information
Seunghui98 committed Oct 25, 2023
1 parent 0b3f914 commit cf59238
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 20 deletions.
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/AddV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ AddV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "AddV2");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "AddV2");
}

/**
* REGISTER_OP("AddV2")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/All.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ AllChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "All");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "All");
}

/**
* REGISTER_OP("All")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ BatchMatMulV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "BatchMatMulV2");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "BatchMatMulV2");
}

/**
* REGISTER_OP("BatchMatMulV2")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ BroadcastToChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "BroadcastTo");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "BroadcastTo");
}

/**
* REGISTER_OP("BroadcastTo")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/Erf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ ErfChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "Erf");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "Erf");
}

/**
* REGISTER_OP("Erf")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ MatMulChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MatMul");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "MatMul");
}

/**
* REGISTER_OP("MatMul")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ MatrixBandPartChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MatrixBandPart");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "MatrixBandPart");
}

/**
* REGISTER_OP("MatrixBandPart")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MaxPoolWithArgmax");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "MaxPoolWithArgmax");
}

/**
* REGISTER_OP("MaxPoolWithArgmax")
Expand Down
27 changes: 15 additions & 12 deletions compiler/tflchef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)

for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if (operation.has_extype() && operation.extype() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -157,10 +157,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if (operation.has_extype() && operation.extype() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -177,9 +177,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
std::set<std::string> customcode_set;
for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if (operation.has_extype() && operation.extype() == "Custom")
customcode_set.insert(operation.custom_code());
}

// Add ops used in Graphs(subgraphs)
Expand All @@ -188,9 +187,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if (operation.has_extype() && operation.extype() == "Custom")
customcode_set.insert(operation.custom_code());
}
}

Expand Down Expand Up @@ -619,7 +617,11 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
{
assert(operation.has_type());

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
std::string op_type = operation.type();
if (operation.has_custom_code())
op_type = operation.custom_code();

auto op_chef = op_chef_registry().lookup(op_type).create(&operation);

// Create 'inputs'
std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
Expand Down Expand Up @@ -650,7 +652,8 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
// custom operator
else
{
auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type());
auto op_it =
std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code());
assert(op_it != custom_code_vec.end());
opcode_index = builtin_code_map.size();
opcode_index += std::distance(custom_code_vec.begin(), op_it);
Expand Down
2 changes: 2 additions & 0 deletions compiler/tflchef/tests/custom_erf/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ operation {
type: "Erf"
input: "ifm"
output: "ofm"
custom_code: "Erf"
extype: "Custom"
}
input: "ifm"
output: "ofm"

0 comments on commit cf59238

Please sign in to comment.