Skip to content

Commit

Permalink
[Draft][tflchef/circlechef] Use a dedicated type for custom operators
Browse files Browse the repository at this point in the history
On going draft to use a dedicated type for custom operators in the recipe.

ONE-DCO-1.0-Signed-off-by: SeungHui Lee <[email protected]>
  • Loading branch information
Seunghui98 committed Oct 25, 2023
1 parent 9dcd0be commit 6950305
Show file tree
Hide file tree
Showing 28 changed files with 110 additions and 32 deletions.
25 changes: 14 additions & 11 deletions compiler/circlechef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ gather_builtincode_map(const ::circlechef::ModelRecipe &model_recipe)

for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
if (operation.has_extype() && operation.extype() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);

// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -151,10 +152,10 @@ gather_builtincode_map(const ::circlechef::ModelRecipe &model_recipe)
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
if (operation.has_extype() && operation.extype() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -171,9 +172,8 @@ std::set<std::string> gather_customcode_set(const ::circlechef::ModelRecipe &mod
std::set<std::string> customcode_set;
for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if (operation.has_extype() && operation.extype() == "Custom")
customcode_set.insert(operation.custom_code());
}

// Add ops used in Graphs(subgraphs)
Expand All @@ -182,9 +182,8 @@ std::set<std::string> gather_customcode_set(const ::circlechef::ModelRecipe &mod
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == circle::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if (operation.has_extype() && operation.extype() == "Custom")
customcode_set.insert(operation.custom_code());
}
}

Expand Down Expand Up @@ -418,7 +417,11 @@ template <typename T> void cook_graph(const T &graph, CookParams &cp)
{
assert(operation.has_type());

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
std::string op_type = operation.type();
if (operation.has_custom_code())
op_type = operation.custom_code();

auto op_chef = op_chef_registry().lookup(op_type).create(&operation);

// Create 'inputs'
std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
Expand Down
2 changes: 2 additions & 0 deletions compiler/circlechef/proto/circlechef.proto
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ message Operation {
repeated string input = 2;
repeated string output = 3;
optional int32 version = 4 [default = 1];
optional string custom_code = 5;

optional string extype = 99;
optional BatchMatMulOptions batch_matmul_options = 100;
optional InstanceNormOptions instance_norm_options = 101;
optional BCQFullyConnectedOptions bcq_fully_connected_options = 102;
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/AddV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ AddV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "AddV2");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "AddV2");
}

/**
* REGISTER_OP("AddV2")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/All.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ AllChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "All");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "All");
}

/**
* REGISTER_OP("All")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/BatchMatMulV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ BatchMatMulV2Chef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "BatchMatMulV2");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "BatchMatMulV2");
}

/**
* REGISTER_OP("BatchMatMulV2")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/BroadcastTo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ BroadcastToChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "BroadcastTo");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "BroadcastTo");
}

/**
* REGISTER_OP("BroadcastTo")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/Erf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ ErfChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "Erf");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "Erf");
}

/**
* REGISTER_OP("Erf")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ MatMulChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MatMul");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "MatMul");
}

/**
* REGISTER_OP("MatMul")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/MatrixBandPart.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ MatrixBandPartChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MatrixBandPart");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "MatrixBandPart");
}

/**
* REGISTER_OP("MatrixBandPart")
Expand Down
6 changes: 5 additions & 1 deletion compiler/tflchef/core/src/CustomOp/MaxPoolWithArgmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ MaxPoolWithArgmaxChef::custom_value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);

assert(operation.type() == "MaxPoolWithArgmax");
if (operation.has_extype() && operation.extype() == "Custom")
{
assert(operation.has_custom_code());
assert(operation.custom_code() == "MaxPoolWithArgmax");
}

/**
* REGISTER_OP("MaxPoolWithArgmax")
Expand Down
27 changes: 15 additions & 12 deletions compiler/tflchef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)

for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if (operation.has_extype() && operation.extype() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -157,10 +157,10 @@ gather_builtincode_map(const ::tflchef::ModelRecipe &model_recipe)
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
if (operation.has_extype() && operation.extype() == "Custom")
continue;

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
// Various operation version is unified as the highest version among them
if (builtin_map.find(op_chef->code()) == builtin_map.end() ||
builtin_map[op_chef->code()] < operation.version())
Expand All @@ -177,9 +177,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
std::set<std::string> customcode_set;
for (const auto &operation : model_recipe.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if (operation.has_extype() && operation.extype() == "Custom")
customcode_set.insert(operation.custom_code());
}

// Add ops used in Graphs(subgraphs)
Expand All @@ -188,9 +187,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
const auto &graph = model_recipe.graph(g);
for (const auto &operation : graph.operation())
{
auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
if (op_chef->code() == tflite::BuiltinOperator_CUSTOM)
customcode_set.insert(operation.type());
if (operation.has_extype() && operation.extype() == "Custom")
customcode_set.insert(operation.custom_code());
}
}

Expand Down Expand Up @@ -619,7 +617,11 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
{
assert(operation.has_type());

auto op_chef = op_chef_registry().lookup(operation.type()).create(&operation);
std::string op_type = operation.type();
if (operation.has_custom_code())
op_type = operation.custom_code();

auto op_chef = op_chef_registry().lookup(op_type).create(&operation);

// Create 'inputs'
std::vector<int32_t> input_vec = as_dataset(operation.input()).map(lookup).vectorize();
Expand Down Expand Up @@ -650,7 +652,8 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
// custom operator
else
{
auto op_it = std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.type());
auto op_it =
std::find(custom_code_vec.begin(), custom_code_vec.end(), operation.custom_code());
assert(op_it != custom_code_vec.end());
opcode_index = builtin_code_map.size();
opcode_index += std::distance(custom_code_vec.begin(), op_it);
Expand Down
2 changes: 2 additions & 0 deletions compiler/tflchef/proto/tflchef.proto
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,9 @@ message Operation {
repeated string input = 2;
repeated string output = 3;
optional int32 version = 4 [default = 1];
optional string custom_code = 5;

optional string extype = 99;
optional Conv2DOptions conv2d_options = 100;
optional Pool2DOptions averagepool2d_options = 101;
optional ConcatenationOptions concatenation_options = 102;
Expand Down
2 changes: 2 additions & 0 deletions compiler/tflchef/tests/custom_erf/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ operation {
type: "Erf"
input: "ifm"
output: "ofm"
custom_code: "Erf"
extype: "Custom"
}
input: "ifm"
output: "ofm"
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/All_000/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ operation {
input: "ifm"
input: "All/reduction_indices"
output: "ofm"
custom_code: "All"
extype: "Custom"
}
input: "ifm"
output: "ofm"
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/BatchMatMulV2_000/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ operation {
input: "ifm1"
input: "ifm2"
output: "ofm"
custom_code: "BatchMatMulV2"
extype: "Custom"
}
input: "ifm1"
input: "ifm2"
Expand Down
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/BatchMatMulV2_001/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ operation {
input: "ifm1"
input: "ifm2"
output: "ofm"
custom_code: "BatchMatMulV2"
extype: "Custom"
}
input: "ifm1"
input: "ifm2"
Expand Down
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/BroadcastTo_000/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ operation {
input: "bc_input"
input: "bc_shape"
output: "bc_ofm"
custom_code: "BroadcastTo"
extype: "Custom"
}
input: "bc_input"
output: "bc_ofm"
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/MatMul_000/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ operation {
transpose_a: true
transpose_b: false
}
custom_code: "MatMul"
extype: "Custom"
}
input: "ifm1"
input: "ifm2"
Expand Down
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/MatrixBandPart_000/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ operation {
input: "MatrixBandPart/num_lower"
input: "MatrixBandPart/num_upper"
output: "ofm"
custom_code: "MatrixBandPart"
extype: "Custom"
}
input: "ifm"
output: "ofm"
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/MaxPoolWithArgmax_000/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ operation {
output_type: INT64
include_batch_in_index: false
}
custom_code: "MaxPoolWithArgmax"
extype: "Custom"
}
input: "ifm"
output: "ofm"
Expand Down
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/MaxPoolWithArgmax_001/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ operation {
output_type: INT32
include_batch_in_index: false
}
custom_code: "MaxPoolWithArgmax"
extype: "Custom"
}
input: "ifm"
output: "ofm"
Expand Down
2 changes: 2 additions & 0 deletions res/TensorFlowLiteRecipes/MaxPoolWithArgmax_002/test.recipe
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ operation {
output_type: INT64
include_batch_in_index: false
}
custom_code: "MaxPoolWithArgmax"
extype: "Custom"
}
input: "ifm"
output: "ofm"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ operation {
input: "bc_input"
input: "bc_shape"
output: "bc_ofm"
custom_code: "BroadcastTo"
extype: "Custom"
}
operand {
name: "reshape_data"
Expand Down Expand Up @@ -57,6 +59,8 @@ operation {
input: "bc_ofm"
input: "reshape_ofm"
output: "ofm"
custom_code: "AddV2"
extype: "Custom"
}
input: "bc_input"
input: "reshape_data"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ operation {
input: "bc_input"
input: "bc_shape"
output: "bc_ofm"
custom_code: "BroadcastTo"
extype: "Custom"
}
operand {
name: "reshape_data"
Expand Down Expand Up @@ -53,10 +55,12 @@ operand {
shape { dim: 1 dim: 2 dim: 3 }
}
operation {
type: "AddV2"
type: "BroadcastTo"
input: "bc_ofm"
input: "reshape_ofm"
output: "ofm"
custom_code: "AddV2"
extype: "Custom"
}
input: "bc_input"
input: "reshape_data"
Expand Down
Loading

0 comments on commit 6950305

Please sign in to comment.