diff --git a/compiler/circle-inspect/driver/Driver.cpp b/compiler/circle-inspect/driver/Driver.cpp index 6371261db87..465f4171824 100644 --- a/compiler/circle-inspect/driver/Driver.cpp +++ b/compiler/circle-inspect/driver/Driver.cpp @@ -37,6 +37,7 @@ int entry(int argc, char **argv) arser.add_argument("--constants").nargs(0).help("Dump constant tensors name"); arser.add_argument("--op_version").nargs(0).help("Dump versions of the operators in circle file"); arser.add_argument("--tensor_dtype").nargs(0).help("Dump dtype of tensors"); + arser.add_argument("--tensor_shape").nargs(0).help("Dump shape of tensors"); arser.add_argument("circle").help("Circle file to inspect"); try @@ -51,7 +52,7 @@ int entry(int argc, char **argv) } if (!arser["--operators"] && !arser["--conv2d_weight"] && !arser["--op_version"] && - !arser["--tensor_dtype"] && !arser["--constants"]) + !arser["--tensor_dtype"] && !arser["--constants"] && !arser["--tensor_shape"]) { std::cout << "At least one option must be specified" << std::endl; std::cout << arser; @@ -70,6 +71,8 @@ int entry(int argc, char **argv) dumps.push_back(std::make_unique()); if (arser["--constants"]) dumps.push_back(std::make_unique()); + if (arser["--tensor_shape"]) + dumps.push_back(std::make_unique()); std::string model_file = arser.get("circle"); diff --git a/compiler/circle-inspect/src/Dump.cpp b/compiler/circle-inspect/src/Dump.cpp index 9d363e1ffa9..08238736be9 100644 --- a/compiler/circle-inspect/src/Dump.cpp +++ b/compiler/circle-inspect/src/Dump.cpp @@ -240,3 +240,38 @@ void DumpConstants::run(std::ostream &os, const circle::Model *model, const std: } } // namespace circleinspect + +namespace circleinspect +{ + +void DumpTensorShape::run(std::ostream &os, const circle::Model *model, + const std::vector *data) +{ + mio::circle::Reader reader(model, data); + + const uint32_t subgraph_size = reader.num_subgraph(); + + for (uint32_t g = 0; g < subgraph_size; g++) + { + reader.select_subgraph(g); + auto tensors = reader.tensors(); + + for (uint32_t i = 0; i < tensors->size(); ++i) + { + const auto tensor = tensors->Get(i); + auto shape = tensor->shape_signature() ? tensor->shape_signature() : tensor->shape(); + os << reader.tensor_name(tensor) << " ["; + for (uint32_t i = 0; i < shape->size(); i++) + { + os << shape->Get(i); + if (i != shape->size() - 1) + { + os << ","; + } + } + os << "]" << std::endl; + } + } +} + +} // namespace circleinspect diff --git a/compiler/circle-inspect/src/Dump.h b/compiler/circle-inspect/src/Dump.h index 12a43a71001..4959adcf1db 100644 --- a/compiler/circle-inspect/src/Dump.h +++ b/compiler/circle-inspect/src/Dump.h @@ -78,6 +78,15 @@ class DumpConstants final : public DumpInterface void run(std::ostream &os, const circle::Model *model, const std::vector *data); }; +class DumpTensorShape final : public DumpInterface +{ +public: + DumpTensorShape() = default; + +public: + void run(std::ostream &os, const circle::Model *model, const std::vector *data); +}; + } // namespace circleinspect #endif // __DUMP_H__