From b34925e739c577fd576fe855bcc2c90107970eb3 Mon Sep 17 00:00:00 2001
From: Meiht <2455220665@qq.com>
Date: Wed, 23 Oct 2024 00:01:03 +0800
Subject: [PATCH] update the model parsing code to the latest PNNX format.

---
 include/runtime/pnnx/ir.h          |  422 ++-
 include/runtime/pnnx/store_zip.hpp |   74 +-
 source/runtime/pnnx/ir.cpp         | 4426 +++++++++++++++++-----------
 source/runtime/pnnx/store_zip.cpp  |  670 +++--
 4 files changed, 3427 insertions(+), 2165 deletions(-)

diff --git a/include/runtime/pnnx/ir.h b/include/runtime/pnnx/ir.h
index 0e3badb7..57c52ea9 100644
--- a/include/runtime/pnnx/ir.h
+++ b/include/runtime/pnnx/ir.h
@@ -15,13 +15,16 @@
 #ifndef PNNX_IR_H
 #define PNNX_IR_H
 
+#include
+#include
 #include
+#include
 #include
 #include
 #include
 #include
 
-#if BUILD_PNNX
+#if BUILD_TORCH2PNNX
 namespace torch {
 namespace jit {
 struct Value;
@@ -31,135 +34,223 @@ struct Node;
 namespace at {
 class Tensor;
 }
-#endif // BUILD_PNNX
+#endif // BUILD_TORCH2PNNX
+
+#if BUILD_ONNX2PNNX
+namespace onnx {
+class AttributeProto;
+class TensorProto;
+class ValueInfoProto;
+} // namespace onnx
+namespace pnnx {
+namespace onnx2pnnx {
+class OnnxAttributeProxy;
+} // namespace onnx2pnnx
+} // namespace pnnx
+#endif // BUILD_ONNX2PNNX
 
 namespace pnnx {
 
 class Parameter {
- public:
-  Parameter()
-      : type(0)
-  {
-  }
-  Parameter(bool _b)
-      : type(1), b(_b)
-  {
-  }
-  Parameter(int _i)
-      : type(2), i(_i)
-  {
-  }
-  Parameter(long _l)
-      : type(2), i(_l)
-  {
-  }
-  Parameter(long long _l)
-      : type(2), i(_l)
-  {
-  }
-  Parameter(float _f)
-      : type(3), f(_f)
-  {
-  }
-  Parameter(double _d)
-      : type(3), f(_d)
-  {
-  }
-  Parameter(const char* _s)
-      : type(4), s(_s)
-  {
-  }
-  Parameter(const std::string& _s)
-      : type(4), s(_s)
-  {
-  }
-  Parameter(const std::initializer_list<int>& _ai)
-      : type(5), ai(_ai)
-  {
-  }
-  Parameter(const std::initializer_list<int64_t>& _ai)
-      : type(5)
-  {
-    for (const auto& x : _ai)
-      ai.push_back((int)x);
-  }
-  Parameter(const std::vector<int>& _ai)
-      : type(5), ai(_ai)
-  {
-  }
-  Parameter(const std::initializer_list<float>& _af)
-      : type(6), af(_af)
-  {
-  }
-  Parameter(const std::initializer_list<double>& _af)
-      : type(6)
-  {
-    for (const auto& x : _af)
-      af.push_back((float)x);
-  }
-  Parameter(const std::vector<float>& _af)
-      : type(6), af(_af)
-  {
-  }
-  Parameter(const std::initializer_list<const char*>& _as)
-      : type(7)
-  {
-    for (const auto& x : _as)
-      as.push_back(std::string(x));
-  }
-  Parameter(const std::initializer_list<std::string>& _as)
-      : type(7), as(_as)
-  {
-  }
-  Parameter(const std::vector<std::string>& _as)
-      : type(7), as(_as)
-  {
-  }
-
-#if BUILD_PNNX
-  Parameter(const torch::jit::Node* value_node);
-  Parameter(const torch::jit::Value* value);
-#endif // BUILD_PNNX
-
-  static Parameter parse_from_string(const std::string& value);
-
-  // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
-  int type;
-
-  // value
-  bool b;
-  int i;
-  float f;
-  std::vector<int> ai;
-  std::vector<float> af;
-
-  // keep std::string typed member the last for cross cxxabi compatibility
-  std::string s;
-  std::vector<std::string> as;
+public:
+    Parameter()
+        : type(0)
+    {
+    }
+    Parameter(bool _b)
+        : type(1), b(_b)
+    {
+    }
+    Parameter(int _i)
+        : type(2), i(_i)
+    {
+    }
+    Parameter(long _l)
+        : type(2)
+    {
+        if (_l == std::numeric_limits<long>::max()) _l = INT_MAX;
+        if (_l == std::numeric_limits<long>::min()) _l = INT_MIN;
+        i = (int)_l;
+    }
+    Parameter(long long _l)
+        : type(2)
+    {
+        if (_l == std::numeric_limits<long long>::max()) _l = INT_MAX;
+        if (_l == std::numeric_limits<long long>::min()) _l = INT_MIN;
+        i = (int)_l;
+    }
+    Parameter(float _f)
+        : type(3), f(_f)
+    {
+    }
+    Parameter(double _d)
+        : type(3), f((float)_d)
+    {
+    }
+    Parameter(const char* _s)
+        : type(4), s(_s)
+    {
+    }
+    Parameter(const std::string& _s)
+        : type(4), s(_s)
+    {
+    }
+    Parameter(const std::initializer_list<int>& _ai)
+        : type(5), ai(_ai)
+    {
+    }
+    Parameter(const std::initializer_list<int64_t>& _ai)
+        : type(5)
+    {
+        for (const auto& x : _ai)
+        {
+            int64_t _l = x;
+            if (_l == std::numeric_limits<int64_t>::max()) _l = INT_MAX;
+            if (_l == std::numeric_limits<int64_t>::min()) _l = INT_MIN;
+            ai.push_back((int)_l);
+        }
+    }
+    Parameter(const std::vector<int>& _ai)
+        : type(5), ai(_ai)
+    {
+    }
+    Parameter(const std::vector<int64_t>& _ai)
+        : type(5)
+    {
+        for (const auto& x : _ai)
+        {
+            int64_t _l = x;
+            if (_l == std::numeric_limits<int64_t>::max()) _l = INT_MAX;
+            if (_l == std::numeric_limits<int64_t>::min()) _l = INT_MIN;
+            ai.push_back((int)_l);
+        }
+    }
+    Parameter(const std::initializer_list<float>& _af)
+        : type(6), af(_af)
+    {
+    }
+    Parameter(const std::initializer_list<double>& _af)
+        : type(6)
+    {
+        for (const auto& x : _af)
+            af.push_back((float)x);
+    }
+    Parameter(const std::vector<float>& _af)
+        : type(6), af(_af)
+    {
+    }
+    Parameter(const std::vector<double>& _af)
+        : type(6)
+    {
+        for (const auto& x : _af)
+            af.push_back((float)x);
+    }
+    Parameter(const std::initializer_list<const char*>& _as)
+        : type(7)
+    {
+        for (const auto& x : _as)
+            as.push_back(std::string(x));
+    }
+    Parameter(const std::initializer_list<std::string>& _as)
+        : type(7), as(_as)
+    {
+    }
+    Parameter(const std::vector<std::string>& _as)
+        : type(7), as(_as)
+    {
+    }
+    Parameter(const std::complex<float>& _c)
+        : type(10), c(_c)
+    {
+    }
+    Parameter(const std::complex<double>& _c)
+        : type(10), c(_c)
+    {
+    }
+    Parameter(const std::initializer_list<std::complex<float> >& _ac)
+        : type(11), ac(_ac)
+    {
+    }
+    Parameter(const std::initializer_list<std::complex<double> >& _ac)
+        : type(11)
+    {
+        for (const auto& x : _ac)
+            ac.push_back(std::complex<float>(x));
+    }
+    Parameter(const std::vector<std::complex<float> >& _ac)
+        : type(11), ac(_ac)
+    {
+    }
+    Parameter(const std::vector<std::complex<double> >& _ac)
+        : type(11)
+    {
+        for (const auto& x : _ac)
+            ac.push_back(std::complex<float>(x));
+    }
+
+#if BUILD_TORCH2PNNX
+    Parameter(const torch::jit::Node* value_node);
+    Parameter(const torch::jit::Value* value);
+#endif // BUILD_TORCH2PNNX
+#if BUILD_ONNX2PNNX
+    Parameter(const onnx::AttributeProto& attr);
+    Parameter(const onnx2pnnx::OnnxAttributeProxy& attr);
+#endif // BUILD_ONNX2PNNX
+
+    static Parameter parse_from_string(const std::string& value);
+    static std::string encode_to_string(const Parameter& param);
+
+    // 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others 10=c 11=ac
+    int type;
+
+    // value
+    bool b;
+    int i;
+    float f;
+    std::complex<float> c;
+    std::vector<int> ai;
+    std::vector<float> af;
+    std::vector<std::complex<float> > ac;
+
+    // keep std::string typed member the last for cross cxxabi compatibility
+    std::string s;
+    std::vector<std::string> as;
 };
 
 bool operator==(const Parameter& lhs, const Parameter& rhs);
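The widened integer constructors above do not truncate blindly: the 64-bit saturation sentinels (numeric_limits max/min) are first clamped to INT_MAX/INT_MIN so that a saturated value stays saturated after narrowing to the 32-bit member i. A minimal standalone sketch of that rule, mirroring Parameter(long long); make_i32 is a hypothetical helper name, not part of the PNNX API:

#include <climits>
#include <cstdio>
#include <limits>

// Standalone mirror of the clamping in Parameter(long long) above.
// make_i32 is a hypothetical name, not part of the PNNX API.
static int make_i32(long long v)
{
    if (v == std::numeric_limits<long long>::max()) return INT_MAX;
    if (v == std::numeric_limits<long long>::min()) return INT_MIN;
    return (int)v; // otherwise plain (possibly lossy) truncation
}

int main()
{
    printf("%d\n", make_i32(std::numeric_limits<long long>::max())); // 2147483647
    printf("%d\n", make_i32(-3LL));                                  // -3
    return 0;
}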
 class Attribute {
- public:
-  Attribute()
-      : type(0)
-  {
-  }
+public:
+    Attribute()
+        : type(0)
+    {
+    }
+
+#if BUILD_TORCH2PNNX
+    Attribute(const at::Tensor& t);
+#endif
+#if BUILD_ONNX2PNNX
+    Attribute(const onnx::TensorProto& t);
+#endif
+
+    Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);
 
-#if BUILD_PNNX
-  Attribute(const at::Tensor& t);
-#endif // BUILD_PNNX
+    size_t elemsize() const;
+    int elemcount() const;
 
-  Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);
+    // convenient routines for manipulate fp32/fp16 weight
+    std::vector<float> get_float32_data() const;
+    void set_float32_data(const std::vector<float>& data);
 
-  // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool
-  int type;
-  std::vector<int> shape;
+    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=c64 11=c128 12=c32 13=bf16
+    int type;
+    std::vector<int> shape;
 
-  std::vector<char> data;
+    std::vector<char> data;
+
+    std::map<std::string, Parameter> params;
 };
 
 bool operator==(const Attribute& lhs, const Attribute& rhs);
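Attribute now carries typed byte payloads (type 13 = bf16 is new) plus fp32/fp16 conversion helpers whose implementations appear in ir.cpp further down. A standalone sketch of the fp32-to-fp16 bit repacking for normal numbers only, assuming IEEE 754 layouts; f32_to_f16_normal is a hypothetical name and omits the zero/denormal/inf/NaN branches the real float32_to_float16 handles:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Sketch of the fp32 -> fp16 repacking used for type 3 (f16) attribute
// payloads; mirrors only the normal-number branch of float32_to_float16
// in ir.cpp below.
static uint16_t f32_to_f16_normal(float x)
{
    uint32_t u;
    memcpy(&u, &x, 4);
    uint32_t sign = (u >> 31) & 0x1;
    int newexp = (int)((u >> 23) & 0xFF) - 127 + 15; // rebias 8-bit exponent to 5-bit
    uint32_t significand = u & 0x7FFFFF;
    return (uint16_t)((sign << 15) | ((uint32_t)newexp << 10) | (significand >> 13));
}

int main()
{
    printf("0x%04x\n", (unsigned)f32_to_f16_normal(1.0f)); // 0x3c00, fp16 encoding of 1.0
    return 0;
}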
@@ -170,72 +261,91 @@ Attribute operator+(const Attribute& a, const Attribute& b);
 class Operator;
 
 class Operand {
- public:
-  void remove_consumer(const Operator* c);
+public:
+    Operand():type(0){};
+
+    void remove_consumer(const Operator* c);
 
-  Operator* producer;
-  std::vector<Operator*> consumers;
+    Operator* producer;
+    std::vector<Operator*> consumers;
 
-  // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=cp64 11=cp128 12=cp32
-  int type;
-  std::vector<int> shape;
+    // 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=c64 11=c128 12=c32 13=bf16
+    int type;
+    std::vector<int> shape;
 
-  // keep std::string typed member the last for cross cxxabi compatibility
-  std::string name;
+    // keep std::string typed member the last for cross cxxabi compatibility
+    std::string name;
 
-  std::map<std::string, Parameter> params;
+    std::map<std::string, Parameter> params;
+
+private:
+    friend class Graph;
 };
 
 class Operator {
- public:
-  std::vector<Operand*> inputs;
-  std::vector<Operand*> outputs;
-
-  // keep std::string typed member the last for cross cxxabi compatibility
-  std::string type;
-  std::string name;
-
-  std::vector<std::string> inputnames;
-  std::map<std::string, Parameter> params;
-  std::map<std::string, Attribute> attrs;
+public:
+    Operator(){}
+    bool has_param(const std::string& key) const;
+    bool has_attr(const std::string& key) const;
+    bool has_input(const std::string& key) const;
+    Operand* named_input(const std::string& key);
+    const Operand* named_input(const std::string& key) const;
+
+    std::vector<Operand*> inputs;
+    std::vector<Operand*> outputs;
+
+    // keep std::string typed member the last for cross cxxabi compatibility
+    std::string type;
+    std::string name;
+
+    std::vector<std::string> inputnames;
+    std::map<std::string, Parameter> params;
+    std::map<std::string, Attribute> attrs;
+
+private:
+    friend class Graph;
+
 };
 
 class Graph {
- public:
-  Graph();
-  ~Graph();
+public:
+    Graph();
+    ~Graph();
 
-  int load(const std::string& parampath, const std::string& binpath);
-  int save(const std::string& parampath, const std::string& binpath);
+    int load(const std::string& parampath, const std::string& binpath);
+    int save(const std::string& parampath, const std::string& binpath);
 
-  int python(const std::string& pypath, const std::string& binpath);
+    int python(const std::string& pypath, const std::string& binpath);
 
-  int parse(const std::string& param);
+    int parse(const std::string& param);
 
-  Operator* new_operator(const std::string& type, const std::string& name);
+    Operator* new_operator(const std::string& type, const std::string& name);
 
-  Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);
+    Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);
 
-  Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);
+    Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);
 
-#if BUILD_PNNX
-  Operand* new_operand(const torch::jit::Value* v);
+#if BUILD_TORCH2PNNX
+    Operand* new_operand(const torch::jit::Value* v);
+#endif
+#if BUILD_ONNX2PNNX
+    Operand* new_operand(const onnx::ValueInfoProto& value);
+    Operand* new_operand(const onnx::TensorProto& t);
 #endif
 
-  Operand* new_operand(const std::string& name);
+    Operand* new_operand(const std::string& name);
 
-  Operand* get_operand(const std::string& name);
-  const Operand* get_operand(const std::string& name) const;
+    Operand* get_operand(const std::string& name);
+    const Operand* get_operand(const std::string& name) const;
 
-  std::vector<Operator*> ops;
-  std::vector<Operand*> operands;
+    std::vector<Operator*> ops;
+    std::vector<Operand*> operands;
 
- private:
-  Graph(const Graph& rhs);
-  Graph& operator=(const Graph& rhs);
+private:
+    Graph(const Graph& rhs);
+    Graph& operator=(const Graph& rhs);
 };
 
 } // namespace pnnx
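The new Operator query helpers (has_param, has_attr, has_input, named_input) are plain linear scans: inputnames[i] labels inputs[i], and a miss returns a null pointer. A minimal sketch of that lookup outside the PNNX types; the Blob struct and key names here are illustrative only:

#include <cstdio>
#include <string>
#include <vector>

// Minimal sketch of the parallel-array lookup behind Operator::named_input.
// Blob and the "weight" key are illustrative, not PNNX types.
struct Blob
{
    std::string name;
};

static Blob* named_input(std::vector<std::string>& inputnames,
                         std::vector<Blob*>& inputs,
                         const std::string& key)
{
    for (size_t i = 0; i < inputnames.size(); i++)
    {
        if (inputnames[i] == key)
            return inputs[i];
    }
    return 0; // miss -> null, matching the helper above
}

int main()
{
    Blob a = {"0"};
    Blob b = {"conv_weight"};
    std::vector<std::string> names = {"input", "weight"};
    std::vector<Blob*> inputs = {&a, &b};

    Blob* w = named_input(names, inputs, "weight");
    printf("%s\n", w ? w->name.c_str() : "(null)");
    return 0;
}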
diff --git a/include/runtime/pnnx/store_zip.hpp b/include/runtime/pnnx/store_zip.hpp
index 101b43d7..352d70b2 100644
--- a/include/runtime/pnnx/store_zip.hpp
+++ b/include/runtime/pnnx/store_zip.hpp
@@ -15,61 +15,67 @@
 #ifndef PNNX_STOREZIP_H
 #define PNNX_STOREZIP_H
 
-#include
+#include
 #include
 #include
 #include
 
 namespace pnnx {
 
-class StoreZipReader {
- public:
-  StoreZipReader();
-  ~StoreZipReader();
+class StoreZipReader
+{
+public:
+    StoreZipReader();
+    ~StoreZipReader();
 
-  int open(const std::string& path);
+    int open(const std::string& path);
 
-  size_t get_file_size(const std::string& name);
+    std::vector<std::string> get_names() const;
 
-  int read_file(const std::string& name, char* data);
+    uint64_t get_file_size(const std::string& name) const;
 
-  int close();
+    int read_file(const std::string& name, char* data);
 
- private:
-  FILE* fp;
+    int close();
 
-  struct StoreZipMeta {
-    size_t offset;
-    size_t size;
-  };
+private:
+    FILE* fp;
 
-  std::map<std::string, StoreZipMeta> filemetas;
+    struct StoreZipMeta
+    {
+        uint64_t offset;
+        uint64_t size;
+    };
+
+    std::map<std::string, StoreZipMeta> filemetas;
 };
 
-class StoreZipWriter {
- public:
-  StoreZipWriter();
-  ~StoreZipWriter();
+class StoreZipWriter
+{
+public:
+    StoreZipWriter();
+    ~StoreZipWriter();
 
-  int open(const std::string& path);
+    int open(const std::string& path);
 
-  int write_file(const std::string& name, const char* data, size_t size);
+    int write_file(const std::string& name, const char* data, uint64_t size);
 
-  int close();
+    int close();
 
- private:
-  FILE* fp;
+private:
+    FILE* fp;
 
-  struct StoreZipMeta {
-    std::string name;
-    size_t lfh_offset;
-    uint32_t crc32;
-    uint32_t size;
-  };
+    struct StoreZipMeta
+    {
+        std::string name;
+        uint64_t lfh_offset;
+        uint32_t crc32;
+        uint64_t size;
+    };
 
-  std::vector<StoreZipMeta> filemetas;
+    std::vector<StoreZipMeta> filemetas;
 };
 
-} // namespace pnnx
+} // namespace pnnx
 
-#endif // PNNX_STOREZIP_H
+#endif // PNNX_STOREZIP_H
diff --git a/source/runtime/pnnx/ir.cpp b/source/runtime/pnnx/ir.cpp
index 98269a9b..c8b6d7a2 100644
--- a/source/runtime/pnnx/ir.cpp
+++ b/source/runtime/pnnx/ir.cpp
@@ -3,7 +3,7 @@
 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
-// in compliance with the License. You may obtain a copy of the License slice
+// in compliance with the License. You may obtain a copy of the License at
 //
 // https://opensource.org/licenses/BSD-3-Clause
 //
@@ -13,2016 +13,2958 @@
 // specific language governing permissions and limitations under the License.
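Before the ir.cpp implementation, a hedged usage sketch for the widened StoreZipReader interface declared above (get_names and 64-bit file sizes): model.pnnx.bin is a placeholder path, and open/read_file are assumed to return 0 on success as in the implementation.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

#include "runtime/pnnx/store_zip.hpp"

// Dump every entry of a pnnx .bin weight archive.
int main()
{
    pnnx::StoreZipReader szr;
    if (szr.open("model.pnnx.bin") != 0) // placeholder path
    {
        fprintf(stderr, "open failed\n");
        return -1;
    }

    for (const std::string& name : szr.get_names())
    {
        uint64_t size = szr.get_file_size(name);
        std::vector<char> data(size);
        if (szr.read_file(name, data.data()) != 0)
            fprintf(stderr, "read %s failed\n", name.c_str());
        else
            fprintf(stderr, "%s  %llu bytes\n", name.c_str(), (unsigned long long)size);
    }

    szr.close();
    return 0;
}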
#include "runtime/pnnx/ir.h" + #include #include #include #include #include #include -#include #include - -#if BUILD_PNNX -#include -#endif +#include #include "runtime/pnnx/store_zip.hpp" +// #include "utils.h" namespace pnnx { -static bool type_is_integer(int type) { - if (type == 1) return false; - if (type == 2) return false; - if (type == 3) return false; - if (type == 4) return true; - if (type == 5) return true; - if (type == 6) return true; - if (type == 7) return true; - if (type == 8) return true; - if (type == 9) return true; - if (type == 10) return false; - if (type == 11) return false; - if (type == 12) return false; - return false; +unsigned short float32_to_float16(float value) +{ + // 1 : 8 : 23 + union + { + unsigned int u; + float f; + } tmp; + + tmp.f = value; + + // 1 : 8 : 23 + unsigned short sign = (tmp.u & 0x80000000) >> 31; + unsigned short exponent = (tmp.u & 0x7F800000) >> 23; + unsigned int significand = tmp.u & 0x7FFFFF; + + // NCNN_LOGE("%d %d %d", sign, exponent, significand); + + // 1 : 5 : 10 + unsigned short fp16; + if (exponent == 0) + { + // zero or denormal, always underflow + fp16 = (sign << 15) | (0x00 << 10) | 0x00; + } + else if (exponent == 0xFF) + { + // infinity or NaN + fp16 = (sign << 15) | (0x1F << 10) | (significand ? 0x200 : 0x00); + } + else + { + // normalized + short newexp = exponent + (-127 + 15); + if (newexp >= 31) + { + // overflow, return infinity + fp16 = (sign << 15) | (0x1F << 10) | 0x00; + } + else if (newexp <= 0) + { + // Some normal fp32 cannot be expressed as normal fp16 + fp16 = (sign << 15) | (0x00 << 10) | 0x00; + } + else + { + // normal fp16 + fp16 = (sign << 15) | (newexp << 10) | (significand >> 13); + } + } + + return fp16; } -static const char* type_to_string(int type) { - if (type == 1) return "f32"; - if (type == 2) return "f64"; - if (type == 3) return "f16"; - if (type == 4) return "i32"; - if (type == 5) return "i64"; - if (type == 6) return "i16"; - if (type == 7) return "i8"; - if (type == 8) return "u8"; - if (type == 9) return "bool"; - if (type == 10) return "cp64"; - if (type == 11) return "cp128"; - if (type == 12) return "cp32"; - return "null"; +float float16_to_float32(unsigned short value) +{ + // 1 : 5 : 10 + unsigned short sign = (value & 0x8000) >> 15; + unsigned short exponent = (value & 0x7c00) >> 10; + unsigned short significand = value & 0x03FF; + + // NCNN_LOGE("%d %d %d", sign, exponent, significand); + + // 1 : 8 : 23 + union + { + unsigned int u; + float f; + } tmp; + if (exponent == 0) + { + if (significand == 0) + { + // zero + tmp.u = (sign << 31); + } + else + { + // denormal + exponent = 0; + // find non-zero bit + while ((significand & 0x200) == 0) + { + significand <<= 1; + exponent++; + } + significand <<= 1; + significand &= 0x3FF; + tmp.u = (sign << 31) | ((-exponent + (-15 + 127)) << 23) | (significand << 13); + } + } + else if (exponent == 0x1F) + { + // infinity or NaN + tmp.u = (sign << 31) | (0xFF << 23) | (significand << 13); + } + else + { + // normalized + tmp.u = (sign << 31) | ((exponent + (-15 + 127)) << 23) | (significand << 13); + } + + return tmp.f; } -static const char* type_to_numpy_string(int type) { - if (type == 1) return "float32"; - if (type == 2) return "float64"; - if (type == 3) return "float16"; - if (type == 4) return "int32"; - if (type == 5) return "int64"; - if (type == 6) return "int16"; - if (type == 7) return "int8"; - if (type == 8) return "uint8"; - if (type == 9) return "bool8"; - if (type == 10) return "csingle"; - if (type == 11) return 
"cdouble"; - if (type == 12) return "chalf"; - return "null"; +static bool type_is_integer(int type) +{ + if (type == 1) return false; + if (type == 2) return false; + if (type == 3) return false; + if (type == 4) return true; + if (type == 5) return true; + if (type == 6) return true; + if (type == 7) return true; + if (type == 8) return true; + if (type == 9) return true; + if (type == 10) return false; + if (type == 11) return false; + if (type == 12) return false; + if (type == 13) return false; + return false; } -static const char* type_to_dtype_string(int type) { - if (type == 1) return "torch.float"; - if (type == 2) return "torch.double"; - if (type == 3) return "torch.half"; - if (type == 4) return "torch.int"; - if (type == 5) return "torch.long"; - if (type == 6) return "torch.short"; - if (type == 7) return "torch.int8"; - if (type == 8) return "torch.uint8"; - if (type == 9) return "torch.bool"; - if (type == 10) return "torch.complex64"; - if (type == 11) return "torch.complex128"; - if (type == 12) return "torch.complex32"; - return "null"; +static const char* type_to_string(int type) +{ + if (type == 1) return "f32"; + if (type == 2) return "f64"; + if (type == 3) return "f16"; + if (type == 4) return "i32"; + if (type == 5) return "i64"; + if (type == 6) return "i16"; + if (type == 7) return "i8"; + if (type == 8) return "u8"; + if (type == 9) return "bool"; + if (type == 10) return "c64"; + if (type == 11) return "c128"; + if (type == 12) return "c32"; + if (type == 13) return "bf16"; + return "null"; } -static size_t type_to_elemsize(int type) { - if (type == 1) return 4; - if (type == 2) return 8; - if (type == 3) return 2; - if (type == 4) return 4; - if (type == 5) return 8; - if (type == 6) return 2; - if (type == 7) return 1; - if (type == 8) return 1; - if (type == 9) return 1; - if (type == 10) return 8; - if (type == 11) return 16; - if (type == 12) return 4; - return 0; // null +static const char* type_to_numpy_string(int type) +{ + if (type == 1) return "float32"; + if (type == 2) return "float64"; + if (type == 3) return "float16"; + if (type == 4) return "int32"; + if (type == 5) return "int64"; + if (type == 6) return "int16"; + if (type == 7) return "int8"; + if (type == 8) return "uint8"; + if (type == 9) return "bool8"; + if (type == 10) return "csingle"; + if (type == 11) return "cdouble"; + if (type == 12) return "chalf"; + if (type == 13) return "bfloat16"; + return "null"; } -static int string_to_type(const char* s) { - if (strcmp(s, "f32") == 0) return 1; - if (strcmp(s, "f64") == 0) return 2; - if (strcmp(s, "f16") == 0) return 3; - if (strcmp(s, "i32") == 0) return 4; - if (strcmp(s, "i64") == 0) return 5; - if (strcmp(s, "i16") == 0) return 6; - if (strcmp(s, "i8") == 0) return 7; - if (strcmp(s, "u8") == 0) return 8; - if (strcmp(s, "bool") == 0) return 9; - if (strcmp(s, "cp64") == 0) return 10; - if (strcmp(s, "cp128") == 0) return 11; - if (strcmp(s, "cp32") == 0) return 12; - return 0; // null +static const char* type_to_dtype_string(int type) +{ + if (type == 1) return "torch.float"; + if (type == 2) return "torch.double"; + if (type == 3) return "torch.half"; + if (type == 4) return "torch.int"; + if (type == 5) return "torch.long"; + if (type == 6) return "torch.short"; + if (type == 7) return "torch.int8"; + if (type == 8) return "torch.uint8"; + if (type == 9) return "torch.bool"; + if (type == 10) return "torch.complex64"; + if (type == 11) return "torch.complex128"; + if (type == 12) return "torch.complex32"; + if (type == 13) return 
"torch.bfloat16"; + return "null"; } -#if BUILD_PNNX -int get_at_tensor_type(const at::ScalarType& st) { - if (st == c10::ScalarType::Float) return 1; - if (st == c10::ScalarType::Double) return 2; - if (st == c10::ScalarType::Half) return 3; - if (st == c10::ScalarType::Int) return 4; - if (st == c10::ScalarType::QInt32) return 4; - if (st == c10::ScalarType::Long) return 5; - if (st == c10::ScalarType::Short) return 6; - if (st == c10::ScalarType::Char) return 7; - if (st == c10::ScalarType::QInt8) return 7; - if (st == c10::ScalarType::Byte) return 8; - if (st == c10::ScalarType::QUInt8) return 8; - if (st == c10::ScalarType::Bool) return 9; - if (st == c10::ScalarType::ComplexFloat) return 10; - if (st == c10::ScalarType::ComplexDouble) return 11; - if (st == c10::ScalarType::ComplexHalf) return 12; - return 0; // unknown type +static size_t type_to_elemsize(int type) +{ + if (type == 1) return 4; + if (type == 2) return 8; + if (type == 3) return 2; + if (type == 4) return 4; + if (type == 5) return 8; + if (type == 6) return 2; + if (type == 7) return 1; + if (type == 8) return 1; + if (type == 9) return 1; + if (type == 10) return 8; + if (type == 11) return 16; + if (type == 12) return 4; + if (type == 13) return 2; + return 0; // null } -Parameter::Parameter(const torch::jit::Node* value_node) { - type = 0; - - if (value_node->kind() == c10::prim::Constant) { - if (!value_node->hasAttribute(torch::jit::attr::value)) { - fprintf(stderr, "no attribute value\n"); - return; - } - - switch (value_node->output()->type()->kind()) { - case c10::TypeKind::NoneType: { - type = 0; - break; - } - case c10::TypeKind::BoolType: { - type = 1; - b = value_node->i(torch::jit::attr::value); - break; - } - case c10::TypeKind::IntType: { - type = 2; - int64_t i64 = value_node->i(torch::jit::attr::value); - if (i64 == LONG_MAX) i64 = INT_MAX; - if (i64 == LONG_MIN) i64 = INT_MIN; - i = (int)i64; - break; - } - case c10::TypeKind::FloatType: { - type = 3; - f = (float)value_node->f(torch::jit::attr::value); - break; - } - case c10::TypeKind::StringType: { - type = 4; - s = value_node->s(torch::jit::attr::value); - break; - } - case c10::TypeKind::TensorType: { - slice::Tensor t = value_node->t(torch::jit::attr::value); - - if (t.dim() == 0) { - if (t.scalar_type() == c10::ScalarType::Long) { - type = 2; - int64_t i64 = t.item(); - if (i64 == LONG_MAX) i64 = INT_MAX; - if (i64 == LONG_MIN) i64 = INT_MIN; - i = (int)i64; - } else if (t.scalar_type() == c10::ScalarType::Int) { - type = 2; - i = t.item(); - } else if (t.scalar_type() == c10::ScalarType::Double) { - type = 3; - f = (float)t.item(); - } else if (t.scalar_type() == c10::ScalarType::Float) { - type = 3; - f = t.item(); - } else { - fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", - value_node->kind().toDisplayString()); - } - } else { - const int ndim = (int)t.dim(); - - type = 8; - fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = %d\n", - value_node->kind().toDisplayString(), ndim); - } - - break; - } - default: { - fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); - break; - } - } - } else if (value_node->kind() == c10::prim::ListConstruct) { - switch (value_node->output()->type()->cast()->getElementType()->kind()) { - case c10::TypeKind::IntType: { - type = 5; - for (const auto& x : value_node->inputs()) { - ai.push_back((int)x->node()->i(torch::jit::attr::value)); - } - break; - } - case c10::TypeKind::FloatType: { - type = 6; - for (const auto& x 
: value_node->inputs()) { - af.push_back((float)x->node()->f(torch::jit::attr::value)); - } - break; - } - case c10::TypeKind::StringType: { - type = 7; - for (const auto& x : value_node->inputs()) { - as.push_back(x->node()->s(torch::jit::attr::value)); - } - break; - } - default: { - fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); - break; - } - } - } else { - fprintf(stderr, "unknown Parameter value kind %s\n", value_node->kind().toDisplayString()); - } +static int string_to_type(const char* s) +{ + if (strcmp(s, "f32") == 0) return 1; + if (strcmp(s, "f64") == 0) return 2; + if (strcmp(s, "f16") == 0) return 3; + if (strcmp(s, "i32") == 0) return 4; + if (strcmp(s, "i64") == 0) return 5; + if (strcmp(s, "i16") == 0) return 6; + if (strcmp(s, "i8") == 0) return 7; + if (strcmp(s, "u8") == 0) return 8; + if (strcmp(s, "bool") == 0) return 9; + if (strcmp(s, "c64") == 0) return 10; + if (strcmp(s, "c128") == 0) return 11; + if (strcmp(s, "c32") == 0) return 12; + if (strcmp(s, "bf16") == 0) return 13; + return 0; // null } -Parameter::Parameter(const torch::jit::Value* value) : Parameter(value->node()) {} -#endif // BUILD_PNNX +bool operator==(const Parameter& lhs, const Parameter& rhs) +{ + if (lhs.type != rhs.type) + return false; -bool operator==(const Parameter& lhs, const Parameter& rhs) { - if (lhs.type != rhs.type) return false; + if (lhs.type == 0) + return true; - if (lhs.type == 0) return true; + if (lhs.type == 1 && lhs.b == rhs.b) + return true; - if (lhs.type == 1 && lhs.b == rhs.b) return true; + if (lhs.type == 2 && lhs.i == rhs.i) + return true; - if (lhs.type == 2 && lhs.i == rhs.i) return true; + if (lhs.type == 3 && lhs.f == rhs.f) + return true; - if (lhs.type == 3 && lhs.f == rhs.f) return true; + if (lhs.type == 4 && lhs.s == rhs.s) + return true; - if (lhs.type == 4 && lhs.s == rhs.s) return true; + if (lhs.type == 5 && lhs.ai == rhs.ai) + return true; - if (lhs.type == 5 && lhs.ai == rhs.ai) return true; + if (lhs.type == 6 && lhs.af == rhs.af) + return true; - if (lhs.type == 6 && lhs.af == rhs.af) return true; + if (lhs.type == 7 && lhs.as == rhs.as) + return true; - if (lhs.type == 7 && lhs.as == rhs.as) return true; + if (lhs.type == 10 && lhs.c == rhs.c) + return true; - return false; + if (lhs.type == 11 && lhs.ac == rhs.ac) + return true; + + return false; } -#if BUILD_PNNX -Attribute::Attribute(const slice::Tensor& t) { - type = get_at_tensor_type(t.scalar_type()); +Attribute::Attribute(const std::initializer_list& _shape, const std::vector& t) +{ + type = 1; + shape = _shape; - const int ndim = (int)t.dim(); + if (shape.size() > 0) + { + data.resize(elemcount() * type_to_elemsize(type)); + memcpy((void*)data.data(), (const void*)t.data(), data.size()); + } +} - if (ndim == 0) { - shape = {1}; +size_t Attribute::elemsize() const +{ + return type_to_elemsize(type); +} - weight_data.resize(type_to_elemsize(type)); +int Attribute::elemcount() const +{ + if (shape.empty()) + return 0; - if (t.scalar_type() == c10::ScalarType::Long) { - int64_t i = t.item(); - memcpy((void*)weight_data.weight_data(), (const void*)&i, weight_data.size()); - } else if (t.scalar_type() == c10::ScalarType::Int) { - int i = t.item(); - memcpy((void*)weight_data.weight_data(), (const void*)&i, weight_data.size()); - } else if (t.scalar_type() == c10::ScalarType::Double) { - double f = t.item(); - memcpy((void*)weight_data.weight_data(), (const void*)&f, weight_data.size()); - } else if (t.scalar_type() == c10::ScalarType::Float) { - float f 
= t.item(); - memcpy((void*)weight_data.weight_data(), (const void*)&f, weight_data.size()); - } else { - fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type); + int size = shape[0]; + for (size_t i = 1; i < shape.size(); i++) + { + size *= shape[i]; } - return; - } + return size; +} - shape.resize(ndim); - for (int i = 0; i < ndim; i++) shape[i] = t.size(i); +std::vector Attribute::get_float32_data() const +{ + std::vector v(elemcount()); - if (shape.size() > 0) { - int size = shape[0]; - for (size_t i = 1; i < shape.size(); i++) { - size *= shape[i]; + if (type == 1) + { + memcpy((void*)v.data(), (const void*)data.data(), data.size()); + } + else if (type == 2) + { + // f64 + const double* p = (const double*)data.data(); + for (size_t i = 0; i < v.size(); i++) + { + v[i] = float(p[i]); + } + } + else if (type == 3) + { + // f16 + const unsigned short* p = (const unsigned short*)data.data(); + for (size_t i = 0; i < v.size(); i++) + { + v[i] = float16_to_float32(p[i]); + } + } + else + { + fprintf(stderr, "cannot convert type %d to float32 data\n", type); } - weight_data.resize(size * type_to_elemsize(type)); - memcpy((void*)weight_data.weight_data(), (const void*)t.cpu().contiguous().data_ptr(), - weight_data.size()); - } + return v; } -#endif // BUILD_PNNX -Attribute::Attribute(const std::initializer_list& _shape, const std::vector& t) { - type = 1; - shape = _shape; +void Attribute::set_float32_data(const std::vector& newdata) +{ + data.resize(newdata.size() * elemsize()); - if (shape.size() > 0) { - int size = shape[0]; - for (size_t i = 1; i < shape.size(); i++) { - size *= shape[i]; + if (type == 1) + { + memcpy((void*)data.data(), (const void*)newdata.data(), data.size()); + } + else if (type == 2) + { + // f64 + double* p = (double*)data.data(); + for (size_t i = 0; i < newdata.size(); i++) + { + p[i] = newdata[i]; + } + } + else if (type == 3) + { + // f16 + unsigned short* p = (unsigned short*)data.data(); + for (size_t i = 0; i < newdata.size(); i++) + { + p[i] = float32_to_float16(newdata[i]); + } + } + else + { + fprintf(stderr, "cannot convert float32 data to type %d\n", type); } - - data.resize(size * type_to_elemsize(type)); - memcpy((void*)data.data(), (const void*)t.data(), data.size()); - } } -bool operator==(const Attribute& lhs, const Attribute& rhs) { - if (lhs.type != rhs.type) return false; +bool operator==(const Attribute& lhs, const Attribute& rhs) +{ + if (lhs.type != rhs.type) + return false; - if (lhs.type == 0) return true; + if (lhs.type == 0) + return true; - if (lhs.shape != rhs.shape) return false; + if (lhs.shape != rhs.shape) + return false; - if (lhs.data != rhs.data) return false; + if (lhs.data != rhs.data) + return false; - return true; + return true; } -Attribute operator+(const Attribute& a, const Attribute& b) { - Attribute c; +Attribute operator+(const Attribute& a, const Attribute& b) +{ + Attribute c; - if (a.type != b.type) { - fprintf(stderr, "concat attribute type mismatch\n"); - return c; - } + if (a.type != b.type) + { + fprintf(stderr, "concat attribute type mismatch\n"); + return c; + } - if (a.shape.size() != b.shape.size()) { - fprintf(stderr, "concat attribute shape rank mismatch\n"); - return c; - } + if (a.shape.size() != b.shape.size()) + { + fprintf(stderr, "concat attribute shape rank mismatch\n"); + return c; + } - for (int i = 1; i < (int)a.shape.size(); i++) { - if (a.shape[i] != b.shape[i]) { - fprintf(stderr, "concat attribute shape mismatch\n"); - return c; + for (int i = 1; i < (int)a.shape.size(); i++) + { 
+ if (a.shape[i] != b.shape[i]) + { + fprintf(stderr, "concat attribute shape mismatch\n"); + return c; + } } - } - c.type = a.type; - c.shape = a.shape; - c.shape[0] += b.shape[0]; // concat the first dim + c.type = a.type; + c.shape = a.shape; + c.shape[0] += b.shape[0]; // concat the first dim - c.data.resize(a.data.size() + b.data.size()); - memcpy(c.data.data(), a.data.data(), a.data.size()); - memcpy(c.data.data() + a.data.size(), b.data.data(), b.data.size()); + c.data.resize(a.data.size() + b.data.size()); + memcpy(c.data.data(), a.data.data(), a.data.size()); + memcpy(c.data.data() + a.data.size(), b.data.data(), b.data.size()); - return c; + return c; } -Parameter Parameter::parse_from_string(const std::string& value) { - Parameter p; - p.type = 0; +Parameter Parameter::parse_from_string(const std::string& value) +{ + if (value.find('%') != std::string::npos) + { + Parameter p; + p.type = 4; + p.s = value; + return p; + } - if (value == "None" || value == "()" || value == "[]") { - return p; - } + Parameter p; + p.type = 0; - if (value == "True" || value == "False") { - // bool - p.type = 1; - p.b = value == "True"; - return p; - } + if (value == "None" || value == "()" || value == "[]") + { + return p; + } - if (value[0] == '(' || value[0] == '[') { - // list - std::string lc = value.substr(1, value.size() - 2); - std::istringstream lcss(lc); + if (value == "True" || value == "False") + { + // bool + p.type = 1; + p.b = value == "True"; + return p; + } - while (!lcss.eof()) { - std::string elem; - std::getline(lcss, elem, ','); + if (value[0] == '(' || value[0] == '[') + { + // list + std::string lc = value.substr(1, value.size() - 2); + std::istringstream lcss(lc); + + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); + + if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) + { + // string + p.type = 7; + p.as.push_back(elem); + } + else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) + { + // float + p.type = 6; + p.af.push_back(std::stof(elem)); + } + else + { + // integer + p.type = 5; + p.ai.push_back(std::stoi(elem)); + } + } + return p; + } - if ((elem[0] != '-' && (elem[0] < '0' || elem[0] > '9')) || - (elem[0] == '-' && (elem[1] < '0' || elem[1] > '9'))) { + if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || (value[0] == '-' && (value[1] < '0' || value[1] > '9'))) + { // string - p.type = 7; - p.as.push_back(elem); - } else if (elem.find('.') != std::string::npos || elem.find('e') != std::string::npos) { - // float - p.type = 6; - p.af.push_back(std::stof(elem)); - } else { - // integer - p.type = 5; - p.ai.push_back(std::stoi(elem)); - } + p.type = 4; + p.s = value; + return p; } - return p; - } - if ((value[0] != '-' && (value[0] < '0' || value[0] > '9')) || - (value[0] == '-' && (value[1] < '0' || value[1] > '9'))) { - // string - p.type = 4; - p.s = value; - return p; - } + if (value.find('.') != std::string::npos || value.find('e') != std::string::npos) + { + // float + p.type = 3; + p.f = std::stof(value); + return p; + } - if (value.find('.') != std::string::npos || value.find('e') != std::string::npos) { - // float - p.type = 3; - p.f = std::stof(value); + // integer + p.type = 2; + p.i = std::stoi(value); return p; - } - - // integer - p.type = 2; - p.i = std::stoi(value); - return p; } -Graph::Graph() {} +std::string Parameter::encode_to_string(const Parameter& param) +{ + if (param.type == 0) + { + return std::string("None"); + } 
+ if (param.type == 1) + { + if (param.b) + return std::string("True"); + else + return std::string("False"); + } + if (param.type == 2) + { + return std::to_string(param.i); + } + if (param.type == 3) + { + char buf[64]; + sprintf(buf, "%e", param.f); + return std::string(buf); + } + if (param.type == 4) + { + return param.s; + } + if (param.type == 5) + { + std::string s("("); + for (size_t i = 0; i < param.ai.size(); i++) + { + s += std::to_string(param.ai[i]); + if (i + 1 != param.ai.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + if (param.type == 6) + { + std::string s("("); + for (size_t i = 0; i < param.af.size(); i++) + { + char buf[64]; + sprintf(buf, "%e", param.af[i]); + s += std::string(buf); + if (i + 1 != param.af.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + if (param.type == 7) + { + std::string s("("); + for (size_t i = 0; i < param.as.size(); i++) + { + s += param.as[i]; + if (i + 1 != param.as.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } + if (param.type == 10) + { + char buf[128]; + sprintf(buf, "%e+%ej", param.c.real(), param.c.imag()); + return std::string(buf); + } + if (param.type == 11) + { + std::string s("("); + for (size_t i = 0; i < param.ac.size(); i++) + { + char buf[128]; + sprintf(buf, "%e+%ej", param.ac[i].real(), param.ac[i].imag()); + s += std::string(buf); + if (i + 1 != param.ac.size()) + s += std::string(","); + } + s += std::string(")"); + return s; + } -Graph::~Graph() { - for (auto x : ops) delete x; + fprintf(stderr, "unknown parameter type %d\n", param.type); + return std::string(); +} - for (auto x : operands) delete x; +bool Operator::has_param(const std::string& key) const +{ + return params.find(key) != params.end(); +} - ops.clear(); - operands.clear(); +bool Operator::has_attr(const std::string& key) const +{ + return attrs.find(key) != attrs.end(); } -Graph::Graph(const Graph& /*rhs*/) {} +bool Operator::has_input(const std::string& key) const +{ + return std::find(inputnames.begin(), inputnames.end(), key) != inputnames.end(); +} -Graph& Graph::operator=(const Graph& /*rhs*/) { return *this; } +Operand* Operator::named_input(const std::string& key) +{ + for (size_t i = 0; i < inputnames.size(); i++) + { + if (inputnames[i] == key) + return inputs[i]; + } -static void load_parameter(Operator* op, const std::string& key, const std::string& value) { - op->params[key] = Parameter::parse_from_string(value); + return 0; } -static void load_input_key(Operator* op, const std::string& key, const std::string& value) { - op->inputnames.resize(op->inputs.size()); - - for (size_t i = 0; i < op->inputs.size(); i++) { - const Operand* oprand = op->inputs[i]; - if (oprand->name == value) { - op->inputnames[i] = key; - break; +const Operand* Operator::named_input(const std::string& key) const +{ + for (size_t i = 0; i < inputnames.size(); i++) + { + if (inputnames[i] == key) + return inputs[i]; } - } + + return 0; } -static void load_shape(Operator* op, const std::string& key, const std::string& value) { - Operand* operand = 0; - for (auto r : op->inputs) { - if (r->name == key) { - operand = r; - break; - } - } - - if (!operand) { - for (auto r : op->outputs) { - if (r->name == key) { - operand = r; - break; - } - } - } - - if (!operand) { - fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str()); - return; - } - - // type - std::string typestr = value.substr(value.find_last_of(')') + 1); - operand->type = 
string_to_type(typestr.c_str()); - - // shape - std::string lc = value.substr(1, value.find_last_of(')') - 1); - std::istringstream lcss(lc); - - operand->shape.clear(); - while (!lcss.eof()) { - std::string elem; - std::getline(lcss, elem, ','); - - if (elem == "?") { - operand->shape.push_back(-1); - } else { - int i = std::stoi(elem); - operand->shape.push_back(i); - } - } +Graph::Graph() +{ } -static void load_attribute(Operator* op, const std::string& key, const std::string& value, - StoreZipReader& szr) { - Attribute& a = op->attrs[key]; +Graph::~Graph() +{ + for (auto x : ops) + delete x; - // type - std::string typestr = value.substr(value.find_last_of(')') + 1); - a.type = string_to_type(typestr.c_str()); + for (auto x : operands) + delete x; - if (a.type == 0) return; + ops.clear(); + operands.clear(); +} - // shape - std::string lc = value.substr(1, value.find_last_of(')') - 1); - std::istringstream lcss(lc); +Graph::Graph(const Graph& /*rhs*/) +{ +} - a.shape.clear(); - while (!lcss.eof()) { - std::string elem; - std::getline(lcss, elem, ','); +Graph& Graph::operator=(const Graph& /*rhs*/) +{ + return *this; +} - int i = std::stoi(elem); - a.shape.push_back(i); - } +static void load_parameter(Operator* op, const std::string& key, const std::string& value) +{ + op->params[key] = Parameter::parse_from_string(value); +} - if (a.shape.empty()) return; +static void load_input_key(Operator* op, const std::string& key, const std::string& value) +{ + op->inputnames.resize(op->inputs.size()); - // weight_data - size_t size = 1; - for (int i : a.shape) { - size *= i; - } + for (size_t i = 0; i < op->inputs.size(); i++) + { + const Operand* oprand = op->inputs[i]; + if (oprand->name == value) + { + op->inputnames[i] = key; + break; + } + } +} - size_t bytesize = size * type_to_elemsize(a.type); +static void load_shape(Operator* op, const std::string& key, const std::string& value) +{ + Operand* operand = 0; + for (auto r : op->inputs) + { + if (r->name == key) + { + operand = r; + break; + } + } - std::string filename = op->name + "." 
+ key; + if (!operand) + { + for (auto r : op->outputs) + { + if (r->name == key) + { + operand = r; + break; + } + } + } - size_t filesize = szr.get_file_size(filename); + if (!operand) + { + fprintf(stderr, "no such operand %s for operator %s\n", key.c_str(), op->name.c_str()); + return; + } - if (filesize == 0) { - // no such file - return; - } + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + operand->type = string_to_type(typestr.c_str()); - if (filesize != bytesize) { - fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize); - } + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); - a.data.resize(bytesize); - szr.read_file(filename, (char*)a.data.data()); -} + operand->shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); -int Graph::load(const std::string& parampath, const std::string& binpath) { - std::ifstream is(parampath, std::ios::in | std::ios::binary); - if (!is.good()) { - fprintf(stderr, "open failed\n"); - return -1; - } - - StoreZipReader szr; - if (szr.open(binpath) != 0) { - fprintf(stderr, "open failed\n"); - return -1; - } - - int magic = 0; - { - std::string line; - std::getline(is, line); - std::istringstream iss(line); - - iss >> magic; - } - - int operator_count = 0; - int operand_count = 0; - { - std::string line; - std::getline(is, line); - std::istringstream iss(line); - - iss >> operator_count >> operand_count; - } - - for (int i = 0; i < operator_count; i++) { - std::string line; - std::getline(is, line); - std::istringstream iss(line); - - std::string type; - std::string name; - int input_count = 0; - int output_count = 0; - - iss >> type >> name >> input_count >> output_count; - - Operator* op = new_operator(type, name); - - for (int j = 0; j < input_count; j++) { - std::string operand_name; - iss >> operand_name; - - Operand* r = get_operand(operand_name); - r->consumers.push_back(op); - op->inputs.push_back(r); - } - - for (int j = 0; j < output_count; j++) { - std::string operand_name; - iss >> operand_name; - - Operand* r = new_operand(operand_name); - r->producer = op; - op->outputs.push_back(r); - } - - // key=value - while (!iss.eof()) { - std::string param; - iss >> param; - - std::string key; - std::string value; - std::istringstream pss(param); - std::getline(pss, key, '='); - std::getline(pss, value); - - if (key[0] == '@') { - // attribute - load_attribute(op, key.substr(1), value, szr); - } else if (key[0] == '$') { - // operand input key - load_input_key(op, key.substr(1), value); - } else if (key[0] == '#') { - // operand shape - load_shape(op, key.substr(1), value); - } else { - // parameter - load_parameter(op, key, value); - } - } - } - - return 0; + if (elem == "?") + { + operand->shape.push_back(-1); + } + else if (elem[0] == '%') + { + // encode %abc as symbolic tag + operand->shape.push_back(-233); + int index = operand->shape.size() - 1; + std::string key = elem.substr(1); + operand->params[std::string("__shape__") + std::to_string(index)] = key; + } + else + { + int i = std::stoi(elem); + operand->shape.push_back(i); + } + } } -int Graph::save(const std::string& parampath, const std::string& binpath) { - FILE* paramfp = fopen(parampath.c_str(), "wb"); - if (!paramfp) { - fprintf(stderr, "fopen %s failed\n", parampath.c_str()); - return -1; - } +static void load_attribute(Operator* op, const std::string& key, const std::string& value, StoreZipReader& szr) +{ + Attribute& a = 
op->attrs[key]; - StoreZipWriter szw; - if (szw.open(binpath) != 0) { - fprintf(stderr, "open failed\n"); - return -1; - } + // type + std::string typestr = value.substr(value.find_last_of(')') + 1); + a.type = string_to_type(typestr.c_str()); - // magic - fprintf(paramfp, "7767517\n"); + if (a.type == 0) + return; - // op count and oprand count - fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); + // shape + std::string lc = value.substr(1, value.find_last_of(')') - 1); + std::istringstream lcss(lc); - for (const Operator* op : ops) { - fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), - (int)op->inputs.size(), (int)op->outputs.size()); + a.shape.clear(); + while (!lcss.eof()) + { + std::string elem; + std::getline(lcss, elem, ','); - for (const Operand* oprand : op->inputs) { - fprintf(paramfp, " %s", oprand->name.c_str()); + int i = std::stoi(elem); + a.shape.push_back(i); } - for (const Operand* oprand : op->outputs) { - fprintf(paramfp, " %s", oprand->name.c_str()); + if (a.shape.empty()) + return; + + // data + size_t size = 1; + for (int i : a.shape) + { + size *= i; } - for (const auto& it : op->params) { - fprintf(paramfp, " %s=", it.first.c_str()); + size_t bytesize = size * type_to_elemsize(a.type); - const Parameter& param = it.second; - if (param.type == 0) { - fprintf(paramfp, "None"); - } - if (param.type == 1) { - if (param.b) - fprintf(paramfp, "True"); - else - fprintf(paramfp, "False"); - } - if (param.type == 2) { - fprintf(paramfp, "%d", param.i); - } - if (param.type == 3) { - fprintf(paramfp, "%e", param.f); - } - if (param.type == 4) { - fprintf(paramfp, "%s", param.s.c_str()); - } - if (param.type == 5) { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.ai.size(); i++) { - fprintf(paramfp, "%d", param.ai[i]); - if (i + 1 != param.ai.size()) fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } - if (param.type == 6) { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.af.size(); i++) { - fprintf(paramfp, "%e", param.af[i]); - if (i + 1 != param.af.size()) fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } - if (param.type == 7) { - fprintf(paramfp, "("); - for (size_t i = 0; i < param.as.size(); i++) { - fprintf(paramfp, "%s", param.as[i].c_str()); - if (i + 1 != param.as.size()) fprintf(paramfp, ","); - } - fprintf(paramfp, ")"); - } - } - - for (const auto& it : op->attrs) { - fprintf(paramfp, " @%s=", it.first.c_str()); - - const Attribute& attr = it.second; - fprintf(paramfp, "("); - for (int i = 0; i < (int)attr.shape.size() - 1; i++) { - fprintf(paramfp, "%d,", attr.shape[i]); - } - if (attr.shape.size() > 0) fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]); - fprintf(paramfp, ")"); - - fprintf(paramfp, type_to_string(attr.type)); - - std::string filename = op->name + "." + it.first; - szw.write_file(filename, attr.data.data(), attr.data.size()); - } - - if (op->inputnames.size() == op->inputs.size()) { - for (size_t i = 0; i < op->inputs.size(); i++) { - if (op->inputnames[i].empty()) continue; + std::string filename = op->name + "." 
+ key; - const Operand* oprand = op->inputs[i]; - fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); - } + size_t filesize = szr.get_file_size(filename); + + if (filesize == 0) + { + // no such file + return; } - for (const Operand* oprand : op->inputs) { - if (oprand->shape.empty()) continue; + if (filesize != bytesize) + { + fprintf(stderr, "file size not match expect %lu but got %lu\n", bytesize, filesize); + } - fprintf(paramfp, " #%s=", oprand->name.c_str()); + a.data.resize(bytesize); + szr.read_file(filename, (char*)a.data.data()); +} - fprintf(paramfp, "("); - for (int i = 0; i < (int)oprand->shape.size() - 1; i++) { - if (oprand->shape[i] == -1) - fprintf(paramfp, "?,"); - else - fprintf(paramfp, "%d,", oprand->shape[i]); - } - if (oprand->shape.size() > 0) { - if (oprand->shape[oprand->shape.size() - 1] == -1) - fprintf(paramfp, "?"); - else - fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); - } - fprintf(paramfp, ")"); +int Graph::load(const std::string& parampath, const std::string& binpath) +{ + std::ifstream is(parampath, std::ios::in | std::ios::binary); + if (!is.good()) + { + fprintf(stderr, "open failed\n"); + return -1; + } - fprintf(paramfp, type_to_string(oprand->type)); + StoreZipReader szr; + if (szr.open(binpath) != 0) + { + fprintf(stderr, "open failed\n"); + return -1; } - for (const Operand* oprand : op->outputs) { - if (oprand->shape.empty()) continue; + int magic = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); - fprintf(paramfp, " #%s=", oprand->name.c_str()); + iss >> magic; + } - fprintf(paramfp, "("); - for (int i = 0; i < (int)oprand->shape.size() - 1; i++) { - if (oprand->shape[i] == -1) - fprintf(paramfp, "?,"); - else - fprintf(paramfp, "%d,", oprand->shape[i]); - } - if (oprand->shape.size() > 0) { - if (oprand->shape[oprand->shape.size() - 1] == -1) - fprintf(paramfp, "?"); - else - fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); - } - fprintf(paramfp, ")"); + int operator_count = 0; + int operand_count = 0; + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); - fprintf(paramfp, type_to_string(oprand->type)); + iss >> operator_count >> operand_count; } - fprintf(paramfp, "\n"); - } + for (int i = 0; i < operator_count; i++) + { + std::string line; + std::getline(is, line); + std::istringstream iss(line); - fclose(paramfp); + std::string type; + std::string name; + int input_count = 0; + int output_count = 0; - return 0; -} + iss >> type >> name >> input_count >> output_count; -static std::string sanitize_identifier(const std::string& s) { - std::string ss = s; - for (size_t i = 0; i < ss.size(); i++) { - if (ss[i] == '.' 
|| ss[i] == ':') ss[i] = '_'; - } + Operator* op = new_operator(type, name); - return ss; -} + for (int j = 0; j < input_count; j++) + { + std::string operand_name; + iss >> operand_name; -static std::string expand_expression(const Operator* op) { - std::string expr = op->params.at("expr").s; - - // split into tokens - std::vector tokens; - { - std::string t; - for (size_t i = 0; i < expr.size(); i++) { - char ch = expr[i]; - - if (ch == '[') // list - { - t += ch; - tokens.push_back(t); - t.clear(); - } else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') { - if (!t.empty()) { - tokens.push_back(t); - t.clear(); - } - } else { - t += ch; - } - } - - if (!t.empty()) { - tokens.push_back(t); - } - } - - // scan and stack - std::stack exprstack; - for (int i = (int)tokens.size() - 1; i >= 0; i--) { - const std::string& t = tokens[i]; - - if (t == "size") { - std::string a = exprstack.top(); - exprstack.pop(); - std::string b = exprstack.top(); - exprstack.pop(); - - std::string r = a + ".size(" + b + ")"; - exprstack.push(r); - } else if (t == "int" || t == "abs" || t == "acos" || t == "acosh" || t == "asin" || - t == "asinh" || t == "atan" || t == "atanh" || t == "ceil" || t == "cos" || - t == "cosh" || t == "exp" || t == "floor" || t == "log" || t == "neg" || - t == "reciprocal" || t == "rsqrt" || t == "sign" || t == "sin" || t == "sinh" || - t == "sqrt" || t == "square" || t == "tan" || t == "tanh" || t == "trunc") { - std::string unaryop; - if (t == "int") unaryop = "int"; - if (t == "abs") unaryop = "torch.abs"; - if (t == "acos") unaryop = "torch.acos"; - if (t == "acosh") unaryop = "torch.acosh"; - if (t == "asin") unaryop = "torch.asin"; - if (t == "asinh") unaryop = "torch.asinh"; - if (t == "atan") unaryop = "torch.atan"; - if (t == "atanh") unaryop = "torch.atanh"; - if (t == "ceil") unaryop = "torch.ceil"; - if (t == "cos") unaryop = "torch.cos"; - if (t == "cosh") unaryop = "torch.cosh"; - if (t == "exp") unaryop = "torch.exp"; - if (t == "floor") unaryop = "torch.floor"; - if (t == "log") unaryop = "torch.log"; - if (t == "neg") unaryop = "torch.neg"; - if (t == "reciprocal") unaryop = "torch.reciprocal"; - if (t == "rsqrt") unaryop = "torch.rsqrt"; - if (t == "sign") unaryop = "torch.sign"; - if (t == "sin") unaryop = "torch.sin"; - if (t == "sinh") unaryop = "torch.sinh"; - if (t == "sqrt") unaryop = "torch.sqrt"; - if (t == "square") unaryop = "torch.square"; - if (t == "tan") unaryop = "torch.tan"; - if (t == "tanh") unaryop = "torch.tanh"; - if (t == "trunc") unaryop = "torch.trunc"; - - std::string a = exprstack.top(); - exprstack.pop(); - - std::string r = unaryop + "(" + a + ")"; - exprstack.push(r); - } else if (t == "atan2" || t == "pow") { - std::string binaryop; - if (t == "atan2") binaryop = "torch.atan2"; - if (t == "pow") binaryop = "torch.pow"; - - std::string a = exprstack.top(); - exprstack.pop(); - std::string b = exprstack.top(); - exprstack.pop(); - - std::string r = binaryop + "(" + a + ", " + b + ")"; - exprstack.push(r); - } else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || - t == "and" || t == "or" || t == "xor") { - std::string binaryop; - if (t == "add") binaryop = "+"; - if (t == "sub") binaryop = "-"; - if (t == "mul") binaryop = "*"; - if (t == "div") binaryop = "/"; - if (t == "floor_divide") binaryop = "//"; - if (t == "and") binaryop = "&"; - if (t == "or") binaryop = "|"; - if (t == "xor") binaryop = "^"; - - std::string a = exprstack.top(); - exprstack.pop(); - std::string b = exprstack.top(); - 
exprstack.pop(); - - std::string r = std::string("(") + a + " " + binaryop + " " + b + ")"; - exprstack.push(r); - } else if (t == "[") // list - { - std::vector elements; - while (!exprstack.empty()) { - std::string a = exprstack.top(); - exprstack.pop(); - - elements.push_back(a); - } - - std::string r = "["; - for (int j = 0; j < (int)elements.size() - 1; j++) { - r += elements[j]; - if (j + 1 != (int)elements.size()) r += ", "; - } - if (!elements.empty()) { - r += elements[elements.size() - 1]; - } - r += "]"; - - exprstack.push(r); - } else if (t[0] == '@') { - int input_index = std::stoi(t.substr(1)); - std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name); - exprstack.push(varid); - } else { - // literal - exprstack.push(t); - } - } - - std::string r = exprstack.top(); - exprstack.pop(); - - return r; -} + Operand* r = get_operand(operand_name); + r->consumers.push_back(op); + op->inputs.push_back(r); + } -static std::string make_slice_expression(const Operator* op) { - for (size_t j = 0; j < op->inputnames.size(); j++) { - fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), - op->inputs[j]->name.c_str()); - } + for (int j = 0; j < output_count; j++) + { + std::string operand_name; + iss >> operand_name; - std::vector dims; - if (op->params.find("dims") != op->params.end()) { - dims = op->params.at("dims").ai; - } else { - dims.push_back(op->params.at("dim").i); - } + Operand* r = new_operand(operand_name); + r->producer = op; + op->outputs.push_back(r); + } - std::string r; + // key=value + while (!iss.eof()) + { + std::string param; + iss >> param; + + std::string key; + std::string value; + std::istringstream pss(param); + std::getline(pss, key, '='); + std::getline(pss, value); + + if (key[0] == '@') + { + // attribute + load_attribute(op, key.substr(1), value, szr); + } + else if (key[0] == '$') + { + // operand input key + load_input_key(op, key.substr(1), value); + } + else if (key[0] == '#') + { + // operand shape + load_shape(op, key.substr(1), value); + } + else + { + // parameter + load_parameter(op, key, value); + } + } + } - int last_dim = -1; - const int ndim = (int)dims.size(); - for (int i = 0; i < ndim; i++) { - int dim = dims[i]; - for (int j = last_dim + 1; j < dim; j++) { - r += ":,"; + return 0; +} + +int Graph::save(const std::string& parampath, const std::string& binpath) +{ + FILE* paramfp = fopen(parampath.c_str(), "wb"); + if (!paramfp) + { + fprintf(stderr, "fopen %s failed\n", parampath.c_str()); + return -1; } - last_dim = dim; - if (op->params.find("starts") != op->params.end()) { - std::vector starts = op->params.at("starts").ai; - int start = starts[i]; + StoreZipWriter szw; + if (szw.open(binpath) != 0) + { + fprintf(stderr, "open failed\n"); + return -1; + } + + // magic + fprintf(paramfp, "7767517\n"); + + // op count and oprand count + fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); - if (start != 0) r += std::to_string(start); - } else { - fprintf(stderr, "find start\n"); - // find start - for (size_t j = 0; j < op->inputnames.size(); j++) { - if (op->inputnames[j] == "start") { - r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); + for (const Operator* op : ops) + { + fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); - fprintf(stderr, "find start %s\n", op->inputs[j]->name.c_str()); - break; + for (const Operand* oprand : op->inputs) + { + fprintf(paramfp, " %s", 
oprand->name.c_str()); + } + + for (const Operand* oprand : op->outputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); } - } - } - r += ':'; + for (const auto& it : op->params) + { + fprintf(paramfp, " %s=", it.first.c_str()); - if (op->params.find("ends") != op->params.end()) { - std::vector ends = op->params.at("ends").ai; - int end = ends[i]; - if (end != INT_MAX) r += std::to_string(end); - } else { - // find end - for (size_t j = 0; j < op->inputnames.size(); j++) { - if (op->inputnames[j] == "end") { - r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); - break; + const Parameter& param = it.second; + std::string s = Parameter::encode_to_string(param); + fprintf(paramfp, "%s", s.c_str()); } - } - } - if (op->params.find("steps") != op->params.end()) { - std::vector steps = op->params.at("steps").ai; - int step = steps[i]; - if (step != 1) { - r += ':'; - r += std::to_string(step); - } - } else { - // find step - for (size_t j = 0; j < op->inputnames.size(); j++) { - if (op->inputnames[j] == "step") { - r += ':'; - r += std::string("v_") + sanitize_identifier(op->inputs[j]->name); - break; + for (const auto& it : op->attrs) + { + fprintf(paramfp, " @%s=", it.first.c_str()); + + const Attribute& attr = it.second; + fprintf(paramfp, "("); + for (int i = 0; i < (int)attr.shape.size() - 1; i++) + { + fprintf(paramfp, "%d,", attr.shape[i]); + } + if (attr.shape.size() > 0) + fprintf(paramfp, "%d", attr.shape[attr.shape.size() - 1]); + fprintf(paramfp, ")"); + + fprintf(paramfp, type_to_string(attr.type)); + + std::string filename = op->name + "." + it.first; + szw.write_file(filename, attr.data.data(), attr.data.size()); + } + + if (op->inputnames.size() == op->inputs.size()) + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (op->inputnames[i].empty()) + continue; + + const Operand* oprand = op->inputs[i]; + fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); + } } - } + + for (const Operand* oprand : op->inputs) + { + if (oprand->shape.empty()) + continue; + + fprintf(paramfp, " #%s=", oprand->name.c_str()); + + fprintf(paramfp, "("); + for (int i = 0; i < (int)oprand->shape.size() - 1; i++) + { + if (oprand->shape[i] == -1) + fprintf(paramfp, "?,"); + else + fprintf(paramfp, "%d,", oprand->shape[i]); + } + if (oprand->shape.size() > 0) + { + if (oprand->shape[oprand->shape.size() - 1] == -1) + fprintf(paramfp, "?"); + else + fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + } + fprintf(paramfp, ")"); + + fprintf(paramfp, type_to_string(oprand->type)); + } + + for (const Operand* oprand : op->outputs) + { + if (oprand->shape.empty()) + continue; + + fprintf(paramfp, " #%s=", oprand->name.c_str()); + + fprintf(paramfp, "("); + for (int i = 0; i < (int)oprand->shape.size() - 1; i++) + { + if (oprand->shape[i] == -1) + fprintf(paramfp, "?,"); + else + fprintf(paramfp, "%d,", oprand->shape[i]); + } + if (oprand->shape.size() > 0) + { + if (oprand->shape[oprand->shape.size() - 1] == -1) + fprintf(paramfp, "?"); + else + fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + } + fprintf(paramfp, ")"); + + fprintf(paramfp, type_to_string(oprand->type)); + } + + fprintf(paramfp, "\n"); } - if (i + 1 != ndim) r += ','; - } + fclose(paramfp); - return r; + return 0; } -static std::string make_index_expression(const Operator* op) { - fprintf(stderr, "make_index_expression %s\n", op->name.c_str()); +static std::string sanitize_identifier(const std::string& s) +{ + std::string ss = s; + for (size_t i = 0; i < 
ss.size(); i++) + { + if (ss[i] == '.' || ss[i] == ':' || ss[i] == '/') + ss[i] = '_'; + } + + return ss; +} + +static std::string expand_expression(const Operator* op) +{ + std::string expr = op->params.at("expr").s; + + // split into tokens + std::vector tokens; + { + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + } - std::string index_expr = op->params.at("expr").s; + // scan and stack + std::stack exprstack; + for (int i = (int)tokens.size() - 1; i >= 0; i--) + { + const std::string& t = tokens[i]; + + if (t == "size") + { + std::string a = exprstack.top(); + exprstack.pop(); + + if (exprstack.empty()) + { + std::string r = a + ".shape"; + exprstack.push(r); + } + else + { + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = a + ".size(" + b + ")"; + exprstack.push(r); + } + } + else if (t == "int" + || t == "abs" + || t == "acos" + || t == "acosh" + || t == "asin" + || t == "asinh" + || t == "atan" + || t == "atanh" + || t == "ceil" + || t == "cos" + || t == "cosh" + || t == "exp" + || t == "floor" + || t == "log" + || t == "log10" + || t == "neg" + || t == "reciprocal" + || t == "round" + || t == "rsqrt" + || t == "sign" + || t == "sin" + || t == "sinh" + || t == "sqrt" + || t == "square" + || t == "tan" + || t == "tanh" + || t == "trunc" + || t == "torch.bool" + || t == "torch.float" + || t == "torch.long") + { + std::string unaryop = t; + if (t == "int") unaryop = "int"; + if (t == "abs") unaryop = "torch.abs"; + if (t == "acos") unaryop = "torch.acos"; + if (t == "acosh") unaryop = "torch.acosh"; + if (t == "asin") unaryop = "torch.asin"; + if (t == "asinh") unaryop = "torch.asinh"; + if (t == "atan") unaryop = "torch.atan"; + if (t == "atanh") unaryop = "torch.atanh"; + if (t == "ceil") unaryop = "torch.ceil"; + if (t == "cos") unaryop = "torch.cos"; + if (t == "cosh") unaryop = "torch.cosh"; + if (t == "exp") unaryop = "torch.exp"; + if (t == "floor") unaryop = "torch.floor"; + if (t == "log") unaryop = "torch.log"; + if (t == "log10") unaryop = "torch.log10"; + if (t == "neg") unaryop = "-"; + if (t == "reciprocal") unaryop = "torch.reciprocal"; + if (t == "round") unaryop = "torch.round"; + if (t == "rsqrt") unaryop = "torch.rsqrt"; + if (t == "sign") unaryop = "torch.sign"; + if (t == "sin") unaryop = "torch.sin"; + if (t == "sinh") unaryop = "torch.sinh"; + if (t == "sqrt") unaryop = "torch.sqrt"; + if (t == "square") unaryop = "torch.square"; + if (t == "tan") unaryop = "torch.tan"; + if (t == "tanh") unaryop = "torch.tanh"; + if (t == "trunc") unaryop = "torch.trunc"; + + std::string a = exprstack.top(); + exprstack.pop(); + + std::string r = unaryop + "(" + a + ")"; + exprstack.push(r); + } + else if (t == "atan2" + || t == "fmod" + || t == "max" + || t == "maximum" + || t == "min" + || t == "minimum" + || t == "pow" + || t == "logaddexp") + { + std::string binaryop; + if (t == "atan2") binaryop = "torch.atan2"; + if (t == "fmod") binaryop = "torch.fmod"; + if (t == "max") binaryop = "torch.max"; + if (t == "maximum") binaryop = "torch.maximum"; + if (t == "min") binaryop = "torch.min"; + if (t == "minimum") binaryop = "torch.minimum"; + if (t == "pow") binaryop = "torch.pow"; + if (t == "logaddexp") binaryop = 
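
sanitize_identifier and the tokenizer above are small but load-bearing: the expression grammar is prefix notation, '[' opens a literal list and is kept as a token of its own, while '(', ')', ',' and ']' merely terminate whatever token is being accumulated. A distilled, runnable version of the tokenizer (the sample expression is hypothetical):

    #include <cstdio>
    #include <string>
    #include <vector>

    // Split a pnnx expression string into prefix tokens using the same
    // delimiter rules as expand_expression.
    static std::vector<std::string> tokenize(const std::string& expr)
    {
        std::vector<std::string> tokens;
        std::string t;
        for (char ch : expr)
        {
            if (ch == '[') // list opener is itself a token
            {
                t += ch;
                tokens.push_back(t);
                t.clear();
            }
            else if (ch == '(' || ch == ')' || ch == ',' || ch == ']')
            {
                if (!t.empty())
                {
                    tokens.push_back(t);
                    t.clear();
                }
            }
            else
            {
                t += ch;
            }
        }
        if (!t.empty())
            tokens.push_back(t);
        return tokens;
    }

    int main()
    {
        for (const auto& t : tokenize("add(@0,mul(@1,2))"))
            printf("%s\n", t.c_str()); // prints: add @0 mul @1 2
        return 0;
    }
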
"torch.logaddexp"; + + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = binaryop + "(" + a + ", " + b + ")"; + exprstack.push(r); + } + else if (t == "add" + || t == "sub" + || t == "mul" + || t == "div" + || t == "floor_divide" + || t == "remainder" + || t == "and" + || t == "or" + || t == "xor" + || t == "lshift" + || t == "rshift") + { + std::string binaryop; + if (t == "add") binaryop = "+"; + if (t == "sub") binaryop = "-"; + if (t == "mul") binaryop = "*"; + if (t == "div") binaryop = "/"; + if (t == "floor_divide") binaryop = "//"; + if (t == "remainder") binaryop = "%"; + if (t == "and") binaryop = "&"; + if (t == "or") binaryop = "|"; + if (t == "xor") binaryop = "^"; + if (t == "lshift") binaryop = "<<"; + if (t == "rshift") binaryop = ">>"; + + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = std::string("(") + a + " " + binaryop + " " + b + ")"; + exprstack.push(r); + } + else if (t == "[") // list + { + std::vector elements; + while (!exprstack.empty()) + { + std::string a = exprstack.top(); + exprstack.pop(); + + elements.push_back(a); + } - // strip out-most [ ] pair - index_expr = index_expr.substr(1, index_expr.size() - 2); + std::string r = "["; + for (int j = 0; j < (int)elements.size() - 1; j++) + { + r += elements[j]; + if (j + 1 != (int)elements.size()) + r += ", "; + } + if (!elements.empty()) + { + r += elements[elements.size() - 1]; + } + r += "]"; - // None,None, -> ..., - bool leading_none = false; - while (index_expr.substr(0, 5) == "None,") { - leading_none = true; - index_expr = index_expr.substr(5); - } - if (leading_none) { - index_expr = "...," + index_expr; - } + exprstack.push(r); + } + else if (t[0] == '@') + { + int input_index = std::stoi(t.substr(1)); + std::string varid = std::string("v_") + sanitize_identifier(op->inputs[input_index]->name); + exprstack.push(varid); + } + else + { + // literal + if (t[t.size() - 1] == 'j') + { + // complex + std::string r = std::string("(") + t + ")"; + exprstack.push(r); + } + else + { + exprstack.push(t); + } + } + } - return index_expr; + std::string r = exprstack.top(); + exprstack.pop(); + + return r; } -int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { - FILE* pyfp = fopen(pypath.c_str(), "wb"); - if (!pyfp) { - fprintf(stderr, "fopen %s failed\n", pypath.c_str()); - return -1; - } - - fprintf(pyfp, "import os\n"); - fprintf(pyfp, "import numpy as np\n"); - fprintf(pyfp, "import tempfile, zipfile\n"); - fprintf(pyfp, "import torch\n"); - fprintf(pyfp, "import torch.nn as nn\n"); - fprintf(pyfp, "import torch.nn.functional as F\n"); - fprintf(pyfp, "try:\n"); - fprintf(pyfp, " import torchvision\n"); - fprintf(pyfp, "except:\n"); - fprintf(pyfp, " pass\n"); - - fprintf(pyfp, "\n"); - - fprintf(pyfp, "class Model(nn.Module):\n"); - fprintf(pyfp, " def __init__(self):\n"); - fprintf(pyfp, " super(Model, self).__init__()\n"); - - fprintf(pyfp, "\n"); - - // module - { - for (const Operator* op : ops) { - if (op->type.substr(0, 3) != "nn." 
&& op->type.substr(0, 16) != "torchvision.ops.") continue; - - fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), - op->type.c_str()); - - int param_count = op->params.size(); - if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") { - param_count -= 2; // ignore scale and zero_point - } - - int param_index = 0; - for (const auto& it : op->params) { - if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") { - if (it.first == "scale" || it.first == "zero_point") continue; - } - - fprintf(pyfp, "%s=", it.first.c_str()); - - const Parameter& param = it.second; - if (param.type == 0) { - fprintf(pyfp, "None"); - } - if (param.type == 1) { - if (param.b) - fprintf(pyfp, "True"); - else - fprintf(pyfp, "False"); - } - if (param.type == 2) { - fprintf(pyfp, "%d", param.i); - } - if (param.type == 3) { - fprintf(pyfp, "%f", param.f); - } - if (param.type == 4) { - if (param.s.substr(0, 6) == "torch.") { - fprintf(pyfp, "%s", param.s.c_str()); - } else { - fprintf(pyfp, "\'%s\'", param.s.c_str()); - } - } - if (param.type == 5) { - fprintf(pyfp, "("); - for (size_t i = 0; i < param.ai.size(); i++) { - fprintf(pyfp, "%d", param.ai[i]); - if (i + 1 != param.ai.size() || param.ai.size() == 1) fprintf(pyfp, ","); - } - fprintf(pyfp, ")"); - } - if (param.type == 6) { - fprintf(pyfp, "("); - for (size_t i = 0; i < param.af.size(); i++) { - fprintf(pyfp, "%f", param.af[i]); - if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); - } - fprintf(pyfp, ")"); - } - if (param.type == 7) { - fprintf(pyfp, "("); - for (size_t i = 0; i < param.as.size(); i++) { - if (param.as[i].substr(0, 6) == "torch.") { - fprintf(pyfp, "%s", param.as[i].c_str()); - } else { - fprintf(pyfp, "\'%s\'", param.as[i].c_str()); - } - if (i + 1 != param.as.size() || param.as.size() == 1) fprintf(pyfp, ","); - } - fprintf(pyfp, ")"); - } - - param_index++; - if (param_index != param_count) fprintf(pyfp, ", "); - } - - fprintf(pyfp, ")\n"); - } - } - - fprintf(pyfp, "\n"); - - // load weights - { - fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str()); - - for (const Operator* op : ops) { - if (op->type.substr(0, 3) != "nn." 
&& op->type.substr(0, 16) != "torchvision.ops.") continue; - - if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") { - for (const auto& it : op->attrs) { - if (it.first == "weight" || it.first == "bias") { - fprintf(pyfp, - " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", - sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), - it.first.c_str()); - } else { - // unknown attr - continue; - } - - const Attribute& attr = it.second; - for (size_t i = 0; i < attr.shape.size(); i++) { - fprintf(pyfp, "%d", attr.shape[i]); - if (i + 1 != attr.shape.size()) fprintf(pyfp, ","); - } - - fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); - } - - fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", - sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), - sanitize_identifier(op->name).c_str()); - fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), - op->params.at("scale").f); - fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), - op->params.at("zero_point").i); - - continue; - } - - for (const auto& it : op->attrs) { - if (it.first == "running_mean" || it.first == "running_var") { - fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", - sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), - it.first.c_str()); - } else { - fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", - sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), - it.first.c_str()); - } - - const Attribute& attr = it.second; - for (size_t i = 0; i < attr.shape.size(); i++) { - fprintf(pyfp, "%d", attr.shape[i]); - if (i + 1 != attr.shape.size()) fprintf(pyfp, ","); - } - - if (attr.type == 1 || attr.type == 2 || attr.type == 3) { - fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); - } else { - fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); - } - } - } - - for (const Operator* op : ops) { - if (op->type != "pnnx.Attribute") continue; - - const std::string& key = op->attrs.begin()->first; - const Attribute& attr = op->attrs.begin()->second; - - bool is_running_mean_var = false; - { - const Operand* r = op->outputs[0]; - if (r->consumers.size() == 1) { - const Operator* op2 = r->consumers[0]; - if (op2->type == "F.batch_norm" || op2->type == "F.instance_norm") { - if (r == op2->inputs[1] || r == op2->inputs[2]) { - is_running_mean_var = true; - } - } - } - } - - if (is_running_mean_var) { - fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", - sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), - op->name.c_str(), key.c_str()); - } else { - fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", - sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), - op->name.c_str(), key.c_str()); - } - - for (size_t i = 0; i < attr.shape.size(); i++) { - fprintf(pyfp, "%d", attr.shape[i]); - if (i + 1 != attr.shape.size()) fprintf(pyfp, ","); - } - - if (attr.type == 1 || attr.type == 2 || attr.type == 3) { - fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); - } else { - fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); - } - } - - fprintf(pyfp, " archive.close()\n"); - } - - fprintf(pyfp, "\n"); - - // utility function - { - fprintf(pyfp, - " def 
load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, " - "requires_grad=True):\n"); - fprintf(pyfp, - " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), " - "requires_grad)\n"); - fprintf(pyfp, "\n"); - fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n"); - fprintf(pyfp, " _, tmppath = tempfile.mkstemp()\n"); - fprintf(pyfp, " tmpf = open(tmppath, 'wb')\n"); - fprintf(pyfp, " with archive.open(key) as keyfile:\n"); - fprintf(pyfp, " tmpf.write(keyfile.read())\n"); - fprintf(pyfp, " tmpf.close()\n"); - fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n"); - fprintf(pyfp, " os.remove(tmppath)\n"); - fprintf(pyfp, " return torch.from_numpy(m)\n"); - } - - fprintf(pyfp, "\n"); - - // def forward - { - fprintf(pyfp, " def forward(self"); - - for (const Operator* op : ops) { - if (op->type != "pnnx.Input") continue; - - fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); - } - - fprintf(pyfp, "):\n"); - } - - // forward body - { - for (const Operator* op : ops) { - if (op->type == "pnnx.Input" || op->type == "pnnx.Output") continue; - - fprintf(pyfp, " "); - - if (op->type == "pnnx.Expression") { - // expr - for (size_t i = 0; i < op->outputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); - if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); - } - std::string expanded_expr = expand_expression(op); - fprintf(pyfp, " = %s\n", expanded_expr.c_str()); - } else if (op->type == "pnnx.Attribute") { - const std::string& key = op->attrs.begin()->first; - fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); - } else if (op->type == "Tensor.slice") { - // slice expr - std::string slice_expr = make_slice_expression(op); - fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); - } else if (op->type == "Tensor.slice_copy") { - // slice copy expr - std::string slice_expr = make_slice_expression(op); - fprintf(pyfp, "v_%s = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->inputs[0]->name).c_str()); - fprintf(pyfp, " v_%s[%s] = v_%s\n", - sanitize_identifier(op->outputs[0]->name).c_str(), slice_expr.c_str(), - sanitize_identifier(op->inputs[1]->name).c_str()); - } else if (op->type == "Tensor.index") { - // index expr - if (op->inputs.size() == 2) { - std::string expanded_expr = expand_expression(op->inputs[1]->producer); - fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); - } else { - std::string index_expr = make_index_expression(op); - fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); - } - } else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") { - // view reshape - fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); - if (op->inputs.size() == 2) { - fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); - } else { - const std::vector& shape = op->params.at("shape").ai; - for (size_t i = 0; i < shape.size(); i++) { - fprintf(pyfp, 
"%d", shape[i]); - if (i + 1 != shape.size()) fprintf(pyfp, ", "); - } +static std::string make_slice_expression(const Operator* op) +{ + // for (size_t j = 0; j < op->inputnames.size(); j++) + // { + // fprintf(stderr, "make_slice_expression %s %s\n", op->inputnames[j].c_str(), op->inputs[j]->name.c_str()); + // } + + std::vector dims; + if (op->has_param("dims")) + { + dims = op->params.at("dims").ai; + } + else + { + dims.push_back(op->params.at("dim").i); + } + + std::string pr; + std::string nr; + + int last_dim = -1; + const int ndim = (int)dims.size(); + for (int i = 0; i < ndim; i++) + { + int dim = dims[i]; + std::string& r = dim < 0 ? nr : pr; + + for (int j = last_dim + 1; j < dim; j++) + { + r += ":,"; } - fprintf(pyfp, ")\n"); - } else if (op->type == "Tensor.repeat") { - // view reshape - fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); - if (op->inputs.size() == 2) { - fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); - } else { - const std::vector& sizes = op->params.at("sizes").ai; - for (size_t i = 0; i < sizes.size(); i++) { - fprintf(pyfp, "%d", sizes[i]); - if (i + 1 != sizes.size()) fprintf(pyfp, ", "); - } + last_dim = dim; + + bool is_select = false; + if (op->has_param("select")) + { + int select = op->params.at("select").i; + if (select != INT_MAX) + { + r += std::to_string(select); + is_select = true; + } + } + if (op->has_param("selects")) + { + std::vector selects = op->params.at("selects").ai; + int select = selects[i]; + if (select != INT_MAX) + { + r += std::to_string(select); + is_select = true; + } + } + if (op->has_input("select")) + { + r += std::string("v_") + sanitize_identifier(op->named_input("select")->name); + is_select = true; + } + if (op->has_input("selects")) + { + // must be pnnx.SliceIndexes + const Operator* op_sliceindexes = op->named_input("selects")->producer; + const std::string& index = op_sliceindexes->params.at("indexes").as[i]; + if (index[0] == '@') + { + int selecti = std::stoi(index.substr(1)); + r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[selecti]->name); + is_select = true; + } + else + { + int select = std::stoi(index); + if (select != INT_MAX) + { + r += std::to_string(select); + is_select = true; + } + } } - fprintf(pyfp, ")\n"); - } else if (op->type == "torch.cat" || op->type == "torch.stack") { - // cat - fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), - op->type.c_str()); - if (op->inputs.size() == 1) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); - } else { - fprintf(pyfp, "("); - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); - } - fprintf(pyfp, ")"); - } - fprintf(pyfp, ", dim=%d", op->params.at("dim").i); - fprintf(pyfp, ")\n"); - } else if (op->type == "torch.einsum") { - // einsum - fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), - op->type.c_str()); - fprintf(pyfp, "\'%s\'", op->params.at("equation").s.c_str()); + if (is_select) + { + if (i + 1 != ndim) + r += ','; + continue; + } - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (op->has_param("start")) + { + int start = op->params.at("start").i; + if (start != 0) + r += std::to_string(start); } - 
fprintf(pyfp, ")\n"); - } else if (op->type == "prim::TupleUnpack") { - for (size_t i = 0; i < op->outputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); - if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); - } - fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); - } else if (op->type == "prim::TupleConstruct") { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); - fprintf(pyfp, " = ("); - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); + else if (op->has_param("starts")) + { + std::vector starts = op->params.at("starts").ai; + int start = starts[i]; + if (start != 0) + r += std::to_string(start); } - fprintf(pyfp, ")\n"); - } else if (op->type == "prim::ListUnpack") { - for (size_t i = 0; i < op->outputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); - if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); - } - fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); - } else if (op->type == "prim::ListConstruct") { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); - fprintf(pyfp, " = ["); - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); - } - fprintf(pyfp, "]\n"); - } else if (op->type == "nn.LSTM") { - if (op->outputs.size() == 1) { - fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); - } else { - fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), - sanitize_identifier(op->outputs[1]->name).c_str(), - sanitize_identifier(op->outputs[2]->name).c_str()); - } - fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); - if (op->inputs.size() == 3) { - fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), - sanitize_identifier(op->inputs[2]->name).c_str()); + else if (op->has_input("start")) + { + r += std::string("v_") + sanitize_identifier(op->named_input("start")->name); } - fprintf(pyfp, ")\n"); - } else if (op->type == "nn.MultiheadAttention") { - if (op->outputs.size() == 1) { - fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); - } else { - for (size_t i = 0; i < op->outputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); - if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); - } - } - fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); - if (op->inputs.size() == 1) { - const char* in0 = sanitize_identifier(op->inputs[0]->name).c_str(); - fprintf(pyfp, "v_%s, v_%s, v_%s", in0, in0, in0); - } else { - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); - } + else // if (op->has_input("starts")) + { + // must be pnnx.SliceIndexes + const Operator* op_sliceindexes = op->named_input("starts")->producer; + const std::string& index = op_sliceindexes->params.at("indexes").as[i]; + if (index[0] == '@') + { + int starti = std::stoi(index.substr(1)); + r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[starti]->name); + } + else + { + int start = std::stoi(index); + if (start != 0) + r += 
std::to_string(start); + } } - fprintf(pyfp, ")\n"); - } else if (op->type.substr(0, 3) == "nn." || op->type.substr(0, 16) == "torchvision.ops.") { - // self.xxx() - for (size_t i = 0; i < op->outputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); - if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); + + r += ':'; + + if (op->has_param("end")) + { + int end = op->params.at("end").i; + if (end != INT_MAX) + r += std::to_string(end); } - fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); + else if (op->has_param("ends")) + { + std::vector ends = op->params.at("ends").ai; + int end = ends[i]; + if (end != INT_MAX) + r += std::to_string(end); } - fprintf(pyfp, ")\n"); - } else if (op->type.find("::") != std::string::npos || - op->type.find(".") != std::string::npos) { - // direct - for (size_t i = 0; i < op->outputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); - if (i + 1 != op->outputs.size()) fprintf(pyfp, ", "); - } - - if (op->type.substr(0, 7) == "Tensor.") { - fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), - op->type.substr(7).c_str()); - - for (size_t i = 1; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); - } - } else { - fprintf(pyfp, " = %s(", op->type.c_str()); - - if (op->inputnames.size() == op->inputs.size()) { - for (size_t i = 0; i < op->inputs.size(); i++) { - if (!op->inputnames[i].empty()) continue; - - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); - } - - for (size_t i = 0; i < op->inputs.size(); i++) { - if (op->inputnames[i].empty()) continue; - - fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), - sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); - } - } else { - for (size_t i = 0; i < op->inputs.size(); i++) { - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); - if (i + 1 != op->inputs.size()) fprintf(pyfp, ", "); - } - } - } - - int i = 0; - for (const auto& it : op->params) { - if (op->type.substr(0, 7) == "Tensor." 
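
Every slice field above follows the same lookup order: a scalar parameter ("start"), a per-dim array parameter ("starts"), a direct operand input, or a pnnx.SliceIndexes producer, with INT_MAX marking an open end and 1 an implicit step; components for positive and negative dims are accumulated separately (pr/nr) and joined with "...". A small sketch of how one start:end:step component is rendered under those sentinel conventions:

    #include <climits>
    #include <cstdio>
    #include <string>

    // Render one Python slice component; INT_MAX = open end, step 1 omitted.
    static std::string slice_component(int start, int end, int step)
    {
        std::string r;
        if (start != 0)
            r += std::to_string(start);
        r += ':';
        if (end != INT_MAX)
            r += std::to_string(end);
        if (step != 1)
        {
            r += ':';
            r += std::to_string(step);
        }
        return r;
    }

    int main()
    {
        printf("%s\n", slice_component(1, INT_MAX, 1).c_str()); // "1:"
        printf("%s\n", slice_component(0, 5, 2).c_str());       // ":5:2"
        printf("%s\n", slice_component(0, INT_MAX, 1).c_str()); // ":"
        return 0;
    }
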
&& i == 0) { - fprintf(pyfp, "%s=", it.first.c_str()); - } else if (op->inputs.empty() && i == 0) { - fprintf(pyfp, "%s=", it.first.c_str()); - } else { - fprintf(pyfp, ", %s=", it.first.c_str()); - } - - i++; - - const Parameter& param = it.second; - if (param.type == 0) { - fprintf(pyfp, "None"); - } - if (param.type == 1) { - if (param.b) - fprintf(pyfp, "True"); + else if (op->has_input("end")) + { + r += std::string("v_") + sanitize_identifier(op->named_input("end")->name); + } + else // if (op->has_input("ends")) + { + // must be pnnx.SliceIndexes + const Operator* op_sliceindexes = op->named_input("ends")->producer; + const std::string& index = op_sliceindexes->params.at("indexes").as[i]; + if (index[0] == '@') + { + int endi = std::stoi(index.substr(1)); + r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[endi]->name); + } else - fprintf(pyfp, "False"); - } - if (param.type == 2) { - fprintf(pyfp, "%d", param.i); - } - if (param.type == 3) { - fprintf(pyfp, "%f", param.f); - } - if (param.type == 4) { - if (param.s.substr(0, 6) == "torch.") { - fprintf(pyfp, "%s", param.s.c_str()); - } else { - fprintf(pyfp, "\'%s\'", param.s.c_str()); - } - } - if (param.type == 5) { - fprintf(pyfp, "("); - for (size_t i = 0; i < param.ai.size(); i++) { - fprintf(pyfp, "%d", param.ai[i]); - if (i + 1 != param.ai.size() || param.ai.size() == 1) fprintf(pyfp, ","); + { + int end = std::stoi(index); + if (end != INT_MAX) + r += std::to_string(end); } - fprintf(pyfp, ")"); - } - if (param.type == 6) { - fprintf(pyfp, "("); - for (size_t i = 0; i < param.af.size(); i++) { - fprintf(pyfp, "%f", param.af[i]); - if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); + } + + if (op->has_param("step")) + { + int step = op->params.at("step").i; + if (step != 1) + { + r += ':'; + r += std::to_string(step); } - fprintf(pyfp, ")"); - } - if (param.type == 7) { - fprintf(pyfp, "("); - for (size_t i = 0; i < param.as.size(); i++) { - if (param.as[i].substr(0, 6) == "torch.") { - fprintf(pyfp, "%s", param.as[i].c_str()); - } else { - fprintf(pyfp, "\'%s\'", param.as[i].c_str()); - } - if (i + 1 != param.as.size() || param.as.size() == 1) fprintf(pyfp, ","); + } + else if (op->has_param("steps")) + { + std::vector steps = op->params.at("steps").ai; + int step = steps[i]; + if (step != 1) + { + r += ':'; + r += std::to_string(step); + } + } + else if (op->has_input("step")) + { + r += ':'; + r += std::string("v_") + sanitize_identifier(op->named_input("step")->name); + } + else // if (op->has_input("steps")) + { + // must be pnnx.SliceIndexes + const Operator* op_sliceindexes = op->named_input("steps")->producer; + const std::string& index = op_sliceindexes->params.at("indexes").as[i]; + if (index[0] == '@') + { + int stepi = std::stoi(index.substr(1)); + r += ':'; + r += std::string("v_") + sanitize_identifier(op_sliceindexes->inputs[stepi]->name); + } + else + { + int step = std::stoi(index); + if (step != 1) + { + r += ':'; + r += std::to_string(step); + } } - fprintf(pyfp, ")"); - } } - fprintf(pyfp, ")\n"); - } else { - fprintf(stderr, "todo %s\n", op->type.c_str()); - } + if (i + 1 != ndim) + r += ','; } - } - // return - { - fprintf(pyfp, " return "); + if (!pr.empty() && !nr.empty()) + return pr + "...," + nr; + + if (pr.empty() && !nr.empty()) + return std::string("...,") + nr; - int output_count = 0; + return pr + nr; +} + +static std::string make_index_expression(const Operator* op) +{ + fprintf(stderr, "make_index_expression %s\n", op->name.c_str()); + + std::string 
index_expr = op->params.at("expr").s; + + // strip out-most [ ] pair + index_expr = index_expr.substr(1, index_expr.size() - 2); + + // None,None, -> ..., + bool leading_none = false; + while (index_expr.substr(0, 5) == "None,") { - for (const Operator* op : ops) { - if (op->type == "pnnx.Output") output_count++; - } + leading_none = true; + index_expr = index_expr.substr(5); + } + if (leading_none) + { + index_expr = "...," + index_expr; } - int output_index = 0; - for (const Operator* op : ops) { - if (op->type != "pnnx.Output") continue; - - fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); - if (output_index + 1 != output_count) fprintf(pyfp, ", "); + return index_expr; +} - output_index++; +int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) +{ + FILE* pyfp = fopen(pypath.c_str(), "wb"); + if (!pyfp) + { + fprintf(stderr, "fopen %s failed\n", pypath.c_str()); + return -1; } + fprintf(pyfp, "import os\n"); + fprintf(pyfp, "import numpy as np\n"); + fprintf(pyfp, "import tempfile, zipfile\n"); + fprintf(pyfp, "import torch\n"); + fprintf(pyfp, "import torch.nn as nn\n"); + fprintf(pyfp, "import torch.nn.functional as F\n"); + fprintf(pyfp, "try:\n"); + fprintf(pyfp, " import torchvision\n"); + fprintf(pyfp, "except:\n"); + fprintf(pyfp, " pass\n"); + fprintf(pyfp, "\n"); - } - fprintf(pyfp, "\n"); + fprintf(pyfp, "class Model(nn.Module):\n"); + fprintf(pyfp, " def __init__(self):\n"); + fprintf(pyfp, " super(Model, self).__init__()\n"); - // export torchscript - { - fprintf(pyfp, "def export_torchscript():\n"); - fprintf(pyfp, " net = Model()\n"); - fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); - fprintf(pyfp, " torch.manual_seed(0)\n"); - std::vector input_names; - for (const Operator* op : ops) { - if (op->type != "pnnx.Input") continue; + // module + { + for (const Operator* op : ops) + { + if (op->type.substr(0, 3) != "nn." 
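
make_index_expression above does only light string surgery: the outermost '[ ]' pair is stripped, and a run of leading "None," terms is collapsed into a single "...," prefix, so an index like [None,None,v_1] becomes ...,v_1. A standalone sketch of that cleanup:

    #include <cstdio>
    #include <string>

    static std::string clean_index_expr(std::string index_expr)
    {
        // strip outermost [ ] pair
        index_expr = index_expr.substr(1, index_expr.size() - 2);

        // None,None, -> ...,
        bool leading_none = false;
        while (index_expr.compare(0, 5, "None,") == 0)
        {
            leading_none = true;
            index_expr = index_expr.substr(5);
        }
        if (leading_none)
            index_expr = "...," + index_expr;

        return index_expr;
    }

    int main()
    {
        printf("%s\n", clean_index_expr("[None,None,v_1]").c_str()); // ...,v_1
        return 0;
    }
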
&& op->type.substr(0, 16) != "torchvision.ops.") + continue; - const Operand* r = op->outputs[0]; - std::string input_name = std::string("v_") + sanitize_identifier(r->name); - if (type_is_integer(r->type)) { - fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); - for (size_t i = 0; i < r->shape.size(); i++) { - fprintf(pyfp, "%d", r->shape[i]); - if (i + 1 != r->shape.size() || r->shape.size() == 1) fprintf(pyfp, ", "); - } - fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); - } else { - fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); - for (size_t i = 0; i < r->shape.size(); i++) { - fprintf(pyfp, "%d, ", r->shape[i]); - } - fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); - } + fprintf(pyfp, " self.%s = %s(", sanitize_identifier(op->name).c_str(), op->type.c_str()); - input_names.push_back(input_name); + int param_count = op->params.size(); + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + param_count -= 2; // ignore scale and zero_point + } + + int param_index = 0; + for (const auto& it : op->params) + { + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + if (it.first == "scale" || it.first == "zero_point") + continue; + } + + fprintf(pyfp, "%s=", it.first.c_str()); + + const Parameter& param = it.second; + if (param.type == 0) + { + fprintf(pyfp, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(pyfp, "True"); + else + fprintf(pyfp, "False"); + } + if (param.type == 2) + { + fprintf(pyfp, "%d", param.i); + } + if (param.type == 3) + { + fprintf(pyfp, "%f", param.f); + } + if (param.type == 4) + { + if (param.s.substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.s.c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.s.c_str()); + } + } + if (param.type == 5) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + if ((op->type == "nn.AdaptiveAvgPool2d" + || op->type == "nn.AdaptiveAvgPool3d" + || op->type == "nn.AdaptiveMaxPool2d" + || op->type == "nn.AdaptiveMaxPool3d") + && it.first == "output_size" && param.ai[i] == 0) + { + fprintf(pyfp, "None"); + } + else + { + fprintf(pyfp, "%d", param.ai[i]); + } + if (i + 1 != param.ai.size() || param.ai.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 6) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(pyfp, "%f", param.af[i]); + if (i + 1 != param.af.size() || param.af.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 7) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + if (param.as[i].substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.as[i].c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.as[i].c_str()); + } + if (i + 1 != param.as.size() || param.as.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + + param_index++; + if (param_index != param_count) + fprintf(pyfp, ", "); + } + + fprintf(pyfp, ")\n"); + } } fprintf(pyfp, "\n"); - if (input_names.size() == 1) { - fprintf(pyfp, " mod = torch.jit.trace(net, %s)\n", input_names[0].c_str()); - } else { - fprintf(pyfp, " mod = torch.jit.trace(net, ("); + // load weights + { + fprintf(pyfp, " archive = zipfile.ZipFile('%s', 'r')\n", pnnxbinpath.c_str()); + + for (const Operator* op : ops) + { + if (op->type.substr(0, 3) != "nn." 
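
The __init__ emitter above switches on the Parameter type tag (0 null, 1 bool, 2 int, 3 float, 4 string, 5 int array, 6 float array, 7 string array) and prints each as a Python literal; note the trailing comma that keeps a one-element tuple a tuple, and that torch.* strings are emitted unquoted so they stay dtype expressions. A condensed sketch with a stand-in struct (not the real pnnx::Parameter):

    #include <cstdio>
    #include <string>
    #include <vector>

    struct Param
    {
        int type; // same tag values as above
        int i;
        std::vector<int> ai;
        std::string s;
    };

    static std::string to_python(const Param& p)
    {
        if (p.type == 0) return "None";
        if (p.type == 2) return std::to_string(p.i);
        if (p.type == 4) // torch.* names stay bare, everything else is quoted
            return p.s.compare(0, 6, "torch.") == 0 ? p.s : "'" + p.s + "'";
        if (p.type == 5)
        {
            std::string r = "(";
            for (size_t i = 0; i < p.ai.size(); i++)
            {
                r += std::to_string(p.ai[i]);
                if (i + 1 != p.ai.size() || p.ai.size() == 1)
                    r += ","; // trailing comma keeps (3,) a tuple
            }
            return r + ")";
        }
        return "?"; // other tags omitted in this sketch
    }

    int main()
    {
        printf("%s\n", to_python({5, 0, {3}, ""}).c_str());           // (3,)
        printf("%s\n", to_python({4, 0, {}, "torch.float"}).c_str()); // torch.float
        printf("%s\n", to_python({4, 0, {}, "zeros"}).c_str());       // 'zeros'
        return 0;
    }
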
&& op->type.substr(0, 16) != "torchvision.ops.") + continue; + + if (op->type == "nn.quantized.Conv2d" || op->type == "nn.quantized.Linear") + { + for (const auto& it : op->attrs) + { + if (it.first == "weight" || it.first == "bias") + { + fprintf(pyfp, " self_%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + else + { + // unknown attr + continue; + } + + const Attribute& attr = it.second; + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d", attr.shape[i]); + if (i + 1 != attr.shape.size()) + fprintf(pyfp, ","); + } + + fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); + } + + fprintf(pyfp, " self.%s.set_weight_bias(self_%s_weight, self_%s_bias)\n", sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(op->name).c_str()); + fprintf(pyfp, " self.%s.scale = %f\n", sanitize_identifier(op->name).c_str(), op->params.at("scale").f); + fprintf(pyfp, " self.%s.zero_point = %d\n", sanitize_identifier(op->name).c_str(), op->params.at("zero_point").i); + + continue; + } - for (size_t i = 0; i < input_names.size(); i++) { - fprintf(pyfp, "%s", input_names[i].c_str()); - if (i + 1 != input_names.size()) fprintf(pyfp, ", "); - } + for (const auto& it : op->attrs) + { + if (it.first == "running_mean" || it.first == "running_var") + { + fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + else + { + fprintf(pyfp, " self.%s.%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), it.first.c_str(), op->name.c_str(), it.first.c_str()); + } + + const Attribute& attr = it.second; + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d", attr.shape[i]); + if (i + 1 != attr.shape.size()) + fprintf(pyfp, ","); + } + + if (attr.type == 1 || attr.type == 2 || attr.type == 3) + { + fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + } + else + { + fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); + } + } + } + + for (const Operator* op : ops) + { + if (op->type != "pnnx.Attribute") + continue; + + const std::string& key = op->attrs.begin()->first; + const Attribute& attr = op->attrs.begin()->second; + + bool is_running_mean_var = false; + { + const Operand* r = op->outputs[0]; + if (r->consumers.size() == 1) + { + const Operator* op2 = r->consumers[0]; + if (op2->type == "F.batch_norm" || op2->type == "F.instance_norm") + { + if (r == op2->inputs[1] || r == op2->inputs[2]) + { + is_running_mean_var = true; + } + } + } + } + + bool is_empty = false; + for (size_t i = 0; i < attr.shape.size(); i++) + { + if (attr.shape[i] == 0) + is_empty = true; + } + + if (is_empty) + { + fprintf(pyfp, " self.%s_%s = torch.from_numpy(np.empty((", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); + + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d,", attr.shape[i]); + } - fprintf(pyfp, "))\n"); + fprintf(pyfp, "), dtype='%s'))\n", type_to_numpy_string(attr.type)); + } + else + { + if (is_running_mean_var) + { + fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); + } + else + { + fprintf(pyfp, " self.%s_%s = 
self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); + } + + for (size_t i = 0; i < attr.shape.size(); i++) + { + fprintf(pyfp, "%d,", attr.shape[i]); + } + + if (attr.type == 1 || attr.type == 2 || attr.type == 3) + { + fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); + } + else + { + fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); + } + } + } + + fprintf(pyfp, " archive.close()\n"); } - fprintf(pyfp, " mod.save(\"%s.pt\")\n", pypath.c_str()); - } + fprintf(pyfp, "\n"); - fprintf(pyfp, "\n"); + // utility function + { + fprintf(pyfp, " def load_pnnx_bin_as_parameter(self, archive, key, shape, dtype, requires_grad=True):\n"); + fprintf(pyfp, " return nn.Parameter(self.load_pnnx_bin_as_tensor(archive, key, shape, dtype), requires_grad)\n"); + fprintf(pyfp, "\n"); + fprintf(pyfp, " def load_pnnx_bin_as_tensor(self, archive, key, shape, dtype):\n"); + fprintf(pyfp, " fd, tmppath = tempfile.mkstemp()\n"); + fprintf(pyfp, " with os.fdopen(fd, 'wb') as tmpf, archive.open(key) as keyfile:\n"); + fprintf(pyfp, " tmpf.write(keyfile.read())\n"); + fprintf(pyfp, " m = np.memmap(tmppath, dtype=dtype, mode='r', shape=shape).copy()\n"); + fprintf(pyfp, " os.remove(tmppath)\n"); + fprintf(pyfp, " return torch.from_numpy(m)\n"); + } - // export onnx - { - fprintf(pyfp, "def export_onnx():\n"); - fprintf(pyfp, " net = Model()\n"); - fprintf(pyfp, " net.eval()\n"); fprintf(pyfp, "\n"); - fprintf(pyfp, " torch.manual_seed(0)\n"); - std::vector input_names; - for (const Operator* op : ops) { - if (op->type != "pnnx.Input") continue; + // def forward + { + fprintf(pyfp, " def forward(self"); + + for (const Operator* op : ops) + { + if (op->type != "pnnx.Input") + continue; - const Operand* r = op->outputs[0]; - std::string input_name = std::string("v_") + sanitize_identifier(r->name); - if (type_is_integer(r->type)) { - fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); - for (size_t i = 0; i < r->shape.size(); i++) { - fprintf(pyfp, "%d", r->shape[i]); - if (i + 1 != r->shape.size() || r->shape.size() == 1) fprintf(pyfp, ", "); - } - fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); - } else { - fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); - for (size_t i = 0; i < r->shape.size(); i++) { - fprintf(pyfp, "%d, ", r->shape[i]); + fprintf(pyfp, ", v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); } - fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); - } - input_names.push_back(input_name); + fprintf(pyfp, "):\n"); } - fprintf(pyfp, "\n"); + // forward body + { + for (const Operator* op : ops) + { + if (op->type == "pnnx.Input" || op->type == "pnnx.Output") + continue; + + if (op->type == "pnnx.SliceIndexes") + continue; + + fprintf(pyfp, " "); + + if (op->type == "pnnx.Expression") + { + // expr + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + std::string expanded_expr = expand_expression(op); + fprintf(pyfp, " = %s\n", expanded_expr.c_str()); + } + else if (op->type == "pnnx.Attribute") + { + const std::string& key = op->attrs.begin()->first; + fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); + } + else if (op->type == "Tensor.slice") + { + // slice 
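
The generated loader above round-trips each attribute through a temp file so np.memmap can view it: entries in the .pnnx.bin zip are raw element data with no header, which also makes them easy to consume from C++. A minimal sketch assuming a float32 blob already extracted from the archive (the file name is hypothetical):

    #include <cstdio>
    #include <fstream>
    #include <vector>

    // Read a raw little-endian float32 blob into a vector.
    static std::vector<float> load_f32_blob(const char* path)
    {
        std::ifstream f(path, std::ios::binary | std::ios::ate);
        if (!f)
            return {};

        std::vector<float> data((size_t)f.tellg() / sizeof(float));
        f.seekg(0);
        f.read(reinterpret_cast<char*>(data.data()), data.size() * sizeof(float));
        return data;
    }

    int main()
    {
        // "conv0.weight.bin" is a made-up name for an extracted zip entry
        std::vector<float> w = load_f32_blob("conv0.weight.bin");
        printf("%zu floats\n", w.size());
        return 0;
    }
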
expr + std::string slice_expr = make_slice_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), slice_expr.c_str()); + } + else if (op->type == "Tensor.slice_copy") + { + // slice copy expr + std::string slice_expr = make_slice_expression(op); + fprintf(pyfp, "v_%s = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str()); + fprintf(pyfp, " v_%s[%s] = v_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), slice_expr.c_str(), sanitize_identifier(op->inputs[1]->name).c_str()); + } + else if (op->type == "Tensor.index") + { + // index expr + if (op->inputs.size() == 2) + { + std::string expanded_expr = expand_expression(op->inputs[1]->producer); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); + } + else + { + std::string index_expr = make_index_expression(op); + fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); + } + } + else if (op->type == "Tensor.expand") + { + // expand + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector& shape = op->params.at("shape").ai; + for (size_t i = 0; i < shape.size(); i++) + { + fprintf(pyfp, "%d", shape[i]); + if (i + 1 != shape.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") + { + // view reshape + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector& shape = op->params.at("shape").ai; + for (size_t i = 0; i < shape.size(); i++) + { + fprintf(pyfp, "%d", shape[i]); + if (i + 1 != shape.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "Tensor.repeat") + { + // view reshape + fprintf(pyfp, "v_%s = v_%s.%s(", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, "*v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + else + { + const std::vector& sizes = op->params.at("sizes").ai; + for (size_t i = 0; i < sizes.size(); i++) + { + fprintf(pyfp, "%d", sizes[i]); + if (i + 1 != sizes.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "torch.cat" || op->type == "torch.stack") + { + // cat + fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); + if (op->inputs.size() == 1) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "("); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, ")"); + } + fprintf(pyfp, ", dim=%d", 
op->params.at("dim").i); + fprintf(pyfp, ")\n"); + } + else if (op->type == "torch.einsum") + { + // einsum + fprintf(pyfp, "v_%s = %s(", sanitize_identifier(op->outputs[0]->name).c_str(), op->type.c_str()); + + fprintf(pyfp, "\'%s\'", op->params.at("equation").s.c_str()); + + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "prim::TupleUnpack") + { + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else if (op->type == "prim::TupleConstruct") + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + fprintf(pyfp, " = ("); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "prim::ListUnpack") + { + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, " = v_%s\n", sanitize_identifier(op->inputs[0]->name).c_str()); + } + else if (op->type == "prim::ListConstruct") + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[0]->name).c_str()); + fprintf(pyfp, " = ["); + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "]\n"); + } + else if (op->type == "nn.GRU" || op->type == "nn.RNN") + { + if (op->outputs.size() == 1) + { + fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "v_%s, v_%s", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str()); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + if (op->inputs.size() == 2) + { + fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[1]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "nn.LSTM") + { + if (op->outputs.size() == 1) + { + fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); + } + else + { + fprintf(pyfp, "v_%s, (v_%s, v_%s)", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str(), sanitize_identifier(op->outputs[2]->name).c_str()); + } + fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str()); + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str()); + if (op->inputs.size() == 3) + { + fprintf(pyfp, ", (v_%s, v_%s)", sanitize_identifier(op->inputs[1]->name).c_str(), sanitize_identifier(op->inputs[2]->name).c_str()); + } + fprintf(pyfp, ")\n"); + } + else if (op->type == "nn.MultiheadAttention") + { + bool need_weights = true; + if (op->outputs.size() == 1) + { + fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str()); + need_weights = false; + } + else + { + for (size_t i = 0; i < op->outputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str()); + if (i + 1 != op->outputs.size()) + fprintf(pyfp, ", "); + } + } + fprintf(pyfp, " = self.%s(", 
sanitize_identifier(op->name).c_str()); + if (op->inputs.size() == 1) + { + std::string in0 = sanitize_identifier(op->inputs[0]->name); + fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in0.c_str(), in0.c_str()); + } + else if (op->inputs.size() == 2) + { + std::string in0 = sanitize_identifier(op->inputs[0]->name); + std::string in1 = sanitize_identifier(op->inputs[1]->name); + if (op->inputnames.size() == 2 && op->inputnames[1] == "attn_mask") + { + fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in0.c_str(), in0.c_str(), in1.c_str()); + } + else + { + fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in1.c_str()); + } + } + else if (op->inputs.size() == 3) + { + std::string in0 = sanitize_identifier(op->inputs[0]->name); + std::string in1 = sanitize_identifier(op->inputs[1]->name); + std::string in2 = sanitize_identifier(op->inputs[2]->name); + if (op->inputnames.size() == 3 && op->inputnames[2] == "attn_mask") + { + fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in1.c_str(), in1.c_str(), in2.c_str()); + } + else + { + fprintf(pyfp, "v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in2.c_str()); + } + } + else if (op->inputs.size() == 4) + { + std::string in0 = sanitize_identifier(op->inputs[0]->name); + std::string in1 = sanitize_identifier(op->inputs[1]->name); + std::string in2 = sanitize_identifier(op->inputs[2]->name); + std::string in3 = sanitize_identifier(op->inputs[3]->name); + if (op->inputnames.size() == 4 && op->inputnames[3] == "attn_mask") + { + fprintf(pyfp, "v_%s, v_%s, v_%s, attn_mask=v_%s", in0.c_str(), in1.c_str(), in2.c_str(), in3.c_str()); + } + else + { + fprintf(pyfp, "v_%s, v_%s, v_%s, v_%s", in0.c_str(), in1.c_str(), in2.c_str(), in3.c_str()); + } + } + else + { + for (size_t i = 0; i < op->inputs.size(); i++) + { + fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str()); + if (i + 1 != op->inputs.size()) + fprintf(pyfp, ", "); + } + } + if (need_weights) + { + fprintf(pyfp, ", need_weights=True"); + } + else + { + fprintf(pyfp, ", need_weights=False"); + } + fprintf(pyfp, ")\n"); + } + else if (op->type.substr(0, 3) == "nn." 
|| op->type.substr(0, 16) == "torchvision.ops.")
+            {
+                // self.xxx()
+                for (size_t i = 0; i < op->outputs.size(); i++)
+                {
+                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
+                    if (i + 1 != op->outputs.size())
+                        fprintf(pyfp, ", ");
+                }
+                fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str());
+                for (size_t i = 0; i < op->inputs.size(); i++)
+                {
+                    fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
+                    if (i + 1 != op->inputs.size())
+                        fprintf(pyfp, ", ");
+                }
+                fprintf(pyfp, ")\n");
+            }
+            else
+            {
+                if (op->type.find("::") == std::string::npos && op->type.find(".") == std::string::npos)
+                {
+                    fprintf(stderr, "todo %s\n", op->type.c_str());
+                }
+
+                // direct
+                for (size_t i = 0; i < op->outputs.size(); i++)
+                {
+                    fprintf(pyfp, "v_%s", sanitize_identifier(op->outputs[i]->name).c_str());
+                    if (i + 1 != op->outputs.size())
+                        fprintf(pyfp, ", ");
+                }
+
+                if (op->type == "torch.max" || op->type == "torch.min")
+                {
+                    if (op->has_param("dim") && op->outputs.size() == 1)
+                    {
+                        // torch.max and torch.min with dim return a tuple
+                        fprintf(pyfp, ", _");
+                    }
+                }
+
+                if (op->type.substr(0, 7) == "Tensor.")
+                {
+                    if (op->type == "Tensor.fill")
+                    {
+                        fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str());
+                    }
+                    else
+                    {
+                        fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
+                    }
+
+                    if (op->inputnames.size() == op->inputs.size())
+                    {
+                        for (size_t i = 1; i < op->inputs.size(); i++)
+                        {
+                            if (!op->inputnames[i].empty())
+                                continue;
+
+                            fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
+                        }
+
+                        for (size_t i = 1; i < op->inputs.size(); i++)
+                        {
+                            if (op->inputnames[i].empty())
+                                continue;
+
+                            fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
+                        }
+                    }
+                    else
+                    {
+                        for (size_t i = 1; i < op->inputs.size(); i++)
+                        {
+                            fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
+                        }
+                    }
+                }
+                else
+                {
+                    fprintf(pyfp, " = %s(", op->type.c_str());
+
+                    if (op->inputnames.size() == op->inputs.size())
+                    {
+                        for (size_t i = 0; i < op->inputs.size(); i++)
+                        {
+                            if (!op->inputnames[i].empty())
+                                continue;
+
+                            fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
+                            if (i + 1 != op->inputs.size())
+                                fprintf(pyfp, ", ");
+                        }
+
+                        for (size_t i = 0; i < op->inputs.size(); i++)
+                        {
+                            if (op->inputnames[i].empty())
+                                continue;
+
+                            fprintf(pyfp, "%s=v_%s", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
+                            if (i + 1 != op->inputs.size())
+                                fprintf(pyfp, ", ");
+                        }
+                    }
+                    else
+                    {
+                        for (size_t i = 0; i < op->inputs.size(); i++)
+                        {
+                            fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[i]->name).c_str());
+                            if (i + 1 != op->inputs.size())
+                                fprintf(pyfp, ", ");
+                        }
+                    }
+                }
+
+                int i = 0;
+                for (const auto& it : op->params)
+                {
+                    if (op->type.substr(0, 7) == "Tensor." 
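
Everything that is not an nn.* or torchvision.ops.* module falls through to the generic emitter above, which picks the Python call syntax from the operator type alone: Tensor.* becomes a method call on the first input (with Tensor.fill mapped to the in-place fill_), and any other dotted or '::'-qualified name is called as a free function. A compact sketch of that dispatch, where v_in, v_out and self.op are placeholder names:

    #include <cstdio>
    #include <string>

    static void emit_call(const std::string& type)
    {
        if (type.compare(0, 3, "nn.") == 0)
            printf("v_out = self.op(v_in)\n");                       // module attribute call
        else if (type.compare(0, 7, "Tensor.") == 0)
            printf("v_out = v_in.%s()\n", type.substr(7).c_str());   // method on first input
        else
            printf("v_out = %s(v_in)\n", type.c_str());              // free function
    }

    int main()
    {
        emit_call("nn.Conv2d");
        emit_call("Tensor.reshape");
        emit_call("torch.relu");
        return 0;
    }

When inputnames line up with inputs, the generic path additionally emits unnamed operands positionally and named ones as keyword arguments, as the loops above show.
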
&& i == 0) + { + fprintf(pyfp, "%s=", it.first.c_str()); + } + else if (op->inputs.empty() && i == 0) + { + fprintf(pyfp, "%s=", it.first.c_str()); + } + else + { + fprintf(pyfp, ", %s=", it.first.c_str()); + } + + i++; + + const Parameter& param = it.second; + if (param.type == 0) + { + if (op->type == "Tensor.index_put" && it.first == "values") + { + fprintf(pyfp, "torch.tensor(False)"); + } + else + { + fprintf(pyfp, "None"); + } + } + if (param.type == 1) + { + if (param.b) + fprintf(pyfp, "True"); + else + fprintf(pyfp, "False"); + } + if (param.type == 2) + { + if (op->type == "Tensor.index_put" && it.first == "values") + { + fprintf(pyfp, "torch.tensor(%d)", param.i); + } + else + { + fprintf(pyfp, "%d", param.i); + } + } + if (param.type == 3) + { + if (op->type == "Tensor.index_put" && it.first == "values") + { + fprintf(pyfp, "torch.tensor(%f)", param.f); + } + else + { + fprintf(pyfp, "%f", param.f); + } + } + if (param.type == 4) + { + if (param.s.substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.s.c_str()); + } + else if (op->type == "Tensor.index_put" && it.first == "values") + { + if (param.s == "inf" || param.s == "-inf") + { + fprintf(pyfp, "torch.tensor(float(\'%s\'))", param.s.c_str()); + } + else + { + fprintf(pyfp, "torch.tensor(\'%s\')", param.s.c_str()); + } + } + else + { + if (param.s == "inf" || param.s == "-inf") + { + fprintf(pyfp, "float(\'%s\')", param.s.c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.s.c_str()); + } + } + } + if (param.type == 5) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + if ((op->type == "F.adaptive_avg_pool2d" + || op->type == "F.adaptive_avg_pool3d" + || op->type == "F.adaptive_max_pool2d" + || op->type == "F.adaptive_max_pool3d") + && it.first == "output_size" && param.ai[i] == 0) + { + fprintf(pyfp, "None"); + } + else + { + fprintf(pyfp, "%d", param.ai[i]); + } + if (i + 1 != param.ai.size() || param.ai.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 6) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(pyfp, "%f", param.af[i]); + if (i + 1 != param.af.size() || param.af.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 7) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + if (param.as[i].substr(0, 6) == "torch.") + { + fprintf(pyfp, "%s", param.as[i].c_str()); + } + else + { + fprintf(pyfp, "\'%s\'", param.as[i].c_str()); + } + if (i + 1 != param.as.size() || param.as.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + if (param.type == 10) + { + fprintf(pyfp, "(%f%+fj)", param.c.real(), param.c.imag()); + } + if (param.type == 11) + { + fprintf(pyfp, "("); + for (size_t i = 0; i < param.ac.size(); i++) + { + fprintf(pyfp, "(%f%+fj)", param.ac[i].real(), param.ac[i].imag()); + if (i + 1 != param.ac.size() || param.ac.size() == 1) + fprintf(pyfp, ","); + } + fprintf(pyfp, ")"); + } + } + + fprintf(pyfp, ")\n"); + } + } + } - // torch.onnx._export(net, v_0, "test_swin_t.onnx", export_params=True, opset_version=14, - // input_names=['in0'], output_names=['out0']) + // return + { + fprintf(pyfp, " return "); + + int output_count = 0; + { + for (const Operator* op : ops) + { + if (op->type == "pnnx.Output") + output_count++; + } + } + + int output_index = 0; + for (const Operator* op : ops) + { + if (op->type != "pnnx.Output") + continue; - if (input_names.size() == 1) { - fprintf(pyfp, " torch.onnx._export(net, %s", 
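
Type tags 10 and 11 above are new in this format and carry complex scalars and arrays; the "%+f" conversion forces a sign on the imaginary part, so "(%f%+fj)" always yields a valid Python complex literal. A two-line check:

    #include <complex>
    #include <cstdio>

    int main()
    {
        std::complex<float> c(1.5f, -2.0f);
        printf("(%f%+fj)\n", c.real(), c.imag()); // (1.500000-2.000000j)
        return 0;
    }
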
input_names[0].c_str());
-  } else {
-    fprintf(pyfp, "    torch.onnx._export(net, (");
+            fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
+            if (output_index + 1 != output_count)
+                fprintf(pyfp, ", ");

-    for (size_t i = 0; i < input_names.size(); i++) {
-      fprintf(pyfp, "%s", input_names[i].c_str());
-      if (i + 1 != input_names.size()) fprintf(pyfp, ", ");
-    }
+            output_index++;
+        }

-    fprintf(pyfp, ")");
+        fprintf(pyfp, "\n");
     }

-  fprintf(
-      pyfp,
-      ", \"%s.onnx\", export_params=True, "
-      "operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13",
-      pypath.c_str());
+    fprintf(pyfp, "\n");

-  fprintf(pyfp, ", input_names=[");
+    // export torchscript
     {
-    int input_count = 0;
-    {
-      for (const Operator* op : ops) {
-        if (op->type == "pnnx.Input") input_count++;
+        fprintf(pyfp, "def export_torchscript():\n");
+        fprintf(pyfp, "    net = Model()\n");
+        fprintf(pyfp, "    net.eval()\n");
+        fprintf(pyfp, "\n");
+        fprintf(pyfp, "    torch.manual_seed(0)\n");
+
+        std::vector<std::string> input_names;
+        for (const Operator* op : ops)
+        {
+            if (op->type != "pnnx.Input")
+                continue;
+
+            const Operand* r = op->outputs[0];
+            std::string input_name = std::string("v_") + sanitize_identifier(r->name);
+            if (type_is_integer(r->type))
+            {
+                fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
+                for (size_t i = 0; i < r->shape.size(); i++)
+                {
+                    fprintf(pyfp, "%d", r->shape[i]);
+                    if (i + 1 != r->shape.size() || r->shape.size() == 1)
+                        fprintf(pyfp, ", ");
+                }
+                fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
+            }
+            else
+            {
+                fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
+                for (size_t i = 0; i < r->shape.size(); i++)
+                {
+                    fprintf(pyfp, "%d, ", r->shape[i]);
+                }
+                fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
+            }
+
+            input_names.push_back(input_name);
         }
-    }
-    int input_index = 0;
-    for (const Operator* op : ops) {
-      if (op->type != "pnnx.Input") continue;
+        fprintf(pyfp, "\n");

-      fprintf(pyfp, "'in%d'", input_index);
-      if (input_index + 1 != input_count) fprintf(pyfp, ", ");
+        if (input_names.size() == 1)
+        {
+            fprintf(pyfp, "    mod = torch.jit.trace(net, %s)\n", input_names[0].c_str());
+        }
+        else
+        {
+            fprintf(pyfp, "    mod = torch.jit.trace(net, (");
+
+            for (size_t i = 0; i < input_names.size(); i++)
+            {
+                fprintf(pyfp, "%s", input_names[i].c_str());
+                if (i + 1 != input_names.size())
+                    fprintf(pyfp, ", ");
+            }
-      input_index++;
-    }
+            fprintf(pyfp, "))\n");
+        }
+
+        fprintf(pyfp, "    mod.save(\"%s.pt\")\n", pypath.c_str());
     }

-  fprintf(pyfp, "]");
-  fprintf(pyfp, ", output_names=[");
+    fprintf(pyfp, "\n");
+
+    // export onnx
     {
-    int output_count = 0;
-    {
-      for (const Operator* op : ops) {
-        if (op->type == "pnnx.Output") output_count++;
+        fprintf(pyfp, "def export_onnx():\n");
+        fprintf(pyfp, "    net = Model()\n");
+        fprintf(pyfp, "    net.eval()\n");
+        fprintf(pyfp, "\n");
+        fprintf(pyfp, "    torch.manual_seed(0)\n");
+
+        std::vector<std::string> input_names;
+        for (const Operator* op : ops)
+        {
+            if (op->type != "pnnx.Input")
+                continue;
+
+            const Operand* r = op->outputs[0];
+            std::string input_name = std::string("v_") + sanitize_identifier(r->name);
+            if (type_is_integer(r->type))
+            {
+                fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
+                for (size_t i = 0; i < r->shape.size(); i++)
+                {
+                    fprintf(pyfp, "%d", r->shape[i]);
+                    if (i + 1 != r->shape.size() || r->shape.size() == 1)
+                        fprintf(pyfp, ", ");
+                }
+                fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
+            }
+            else
+            {
+                fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
+                
for (size_t i = 0; i < r->shape.size(); i++)
+                {
+                    fprintf(pyfp, "%d, ", r->shape[i]);
+                }
+                fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
+            }
+
+            input_names.push_back(input_name);
         }
-    }
-    int output_index = 0;
-    for (const Operator* op : ops) {
-      if (op->type != "pnnx.Output") continue;
+        fprintf(pyfp, "\n");

-      fprintf(pyfp, "'out%d'", output_index);
-      if (output_index + 1 != output_count) fprintf(pyfp, ", ");
+        // torch.onnx._export(net, v_0, "test_swin_t.onnx", export_params=True, opset_version=14, input_names=['in0'], output_names=['out0'])

-      output_index++;
-    }
-  }
-  fprintf(pyfp, "]");
+        if (input_names.size() == 1)
+        {
+            fprintf(pyfp, "    torch.onnx._export(net, %s", input_names[0].c_str());
+        }
+        else
+        {
+            fprintf(pyfp, "    torch.onnx._export(net, (");
+
+            for (size_t i = 0; i < input_names.size(); i++)
+            {
+                fprintf(pyfp, "%s", input_names[i].c_str());
+                if (i + 1 != input_names.size())
+                    fprintf(pyfp, ", ");
+            }
-    fprintf(pyfp, ")\n");
-  }
+            fprintf(pyfp, ")");
+        }

-  fprintf(pyfp, "\n");
+        fprintf(pyfp, ", \"%s.onnx\", export_params=True, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, opset_version=13", pypath.c_str());
+
+        fprintf(pyfp, ", input_names=[");
+        {
+            int input_count = 0;
+            {
+                for (const Operator* op : ops)
+                {
+                    if (op->type == "pnnx.Input")
+                        input_count++;
+                }
+            }

-  // test inference
-  {
-    fprintf(pyfp, "def test_inference():\n");
-    fprintf(pyfp, "    net = Model()\n");
-    fprintf(pyfp, "    net.eval()\n");
-    fprintf(pyfp, "\n");
-    fprintf(pyfp, "    torch.manual_seed(0)\n");
+            int input_index = 0;
+            for (const Operator* op : ops)
+            {
+                if (op->type != "pnnx.Input")
+                    continue;

-    std::vector<std::string> input_names;
-    for (const Operator* op : ops) {
-      if (op->type != "pnnx.Input") continue;
+                fprintf(pyfp, "'in%d'", input_index);
+                if (input_index + 1 != input_count)
+                    fprintf(pyfp, ", ");

-      const Operand* r = op->outputs[0];
-      std::string input_name = std::string("v_") + sanitize_identifier(r->name);
-      if (type_is_integer(r->type)) {
-        fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
-        for (size_t i = 0; i < r->shape.size(); i++) {
-          fprintf(pyfp, "%d", r->shape[i]);
-          if (i + 1 != r->shape.size() || r->shape.size() == 1) fprintf(pyfp, ", ");
+                input_index++;
+            }
         }
-        fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
-      } else {
-        fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
-        for (size_t i = 0; i < r->shape.size(); i++) {
-          fprintf(pyfp, "%d, ", r->shape[i]);
+        fprintf(pyfp, "]");
+
+        fprintf(pyfp, ", output_names=[");
+        {
+            int output_count = 0;
+            {
+                for (const Operator* op : ops)
+                {
+                    if (op->type == "pnnx.Output")
+                        output_count++;
+                }
+            }
+
+            int output_index = 0;
+            for (const Operator* op : ops)
+            {
+                if (op->type != "pnnx.Output")
+                    continue;
+
+                fprintf(pyfp, "'out%d'", output_index);
+                if (output_index + 1 != output_count)
+                    fprintf(pyfp, ", ");
+
+                output_index++;
+            }
         }
-        fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
-      }
+        fprintf(pyfp, "]");

-      input_names.push_back(input_name);
+        fprintf(pyfp, ")\n");
     }

     fprintf(pyfp, "\n");

-    if (input_names.size() == 1) {
-      fprintf(pyfp, "    return net(%s)\n", input_names[0].c_str());
-    } else {
-      fprintf(pyfp, "    return net(");
+    // test inference
+    {
+        fprintf(pyfp, "def test_inference():\n");
+        fprintf(pyfp, "    net = Model()\n");
+        fprintf(pyfp, "    net.eval()\n");
+        fprintf(pyfp, "\n");
+        fprintf(pyfp, "    torch.manual_seed(0)\n");
+
+        std::vector<std::string> input_names;
+        for (const Operator* op : ops)
+        {
+            if (op->type != "pnnx.Input")
+                continue;
+
+            const 
Operand* r = op->outputs[0];
+            std::string input_name = std::string("v_") + sanitize_identifier(r->name);
+            if (type_is_integer(r->type))
+            {
+                fprintf(pyfp, "    %s = torch.randint(10, (", input_name.c_str());
+                for (size_t i = 0; i < r->shape.size(); i++)
+                {
+                    fprintf(pyfp, "%d", r->shape[i]);
+                    if (i + 1 != r->shape.size() || r->shape.size() == 1)
+                        fprintf(pyfp, ", ");
+                }
+                fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
+            }
+            else
+            {
+                fprintf(pyfp, "    %s = torch.rand(", input_name.c_str());
+                for (size_t i = 0; i < r->shape.size(); i++)
+                {
+                    fprintf(pyfp, "%d, ", r->shape[i]);
+                }
+                fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
+            }
+
+            input_names.push_back(input_name);
+        }
+
+        fprintf(pyfp, "\n");

-    for (size_t i = 0; i < input_names.size(); i++) {
-      fprintf(pyfp, "%s", input_names[i].c_str());
-      if (i + 1 != input_names.size()) fprintf(pyfp, ", ");
-    }
+        if (input_names.size() == 1)
+        {
+            fprintf(pyfp, "    return net(%s)\n", input_names[0].c_str());
+        }
+        else
+        {
+            fprintf(pyfp, "    return net(");
+
+            for (size_t i = 0; i < input_names.size(); i++)
+            {
+                fprintf(pyfp, "%s", input_names[i].c_str());
+                if (i + 1 != input_names.size())
+                    fprintf(pyfp, ", ");
+            }

-    fprintf(pyfp, ")\n");
-  }
+            fprintf(pyfp, ")\n");
+        }
     }

-  fclose(pyfp);
+    fprintf(pyfp, "\n");

-  return 0;
-}
+    // main
+    {
+        fprintf(pyfp, "if __name__ == \"__main__\":\n");
+        fprintf(pyfp, "    print(test_inference())\n");
+    }
+
+    fclose(pyfp);
+
+    return 0;
+}

-int Graph::parse(const std::string& param) {
-  std::istringstream is(param);
-  if (!is.good()) {
-    fprintf(stderr, "open failed\n");
-    return -1;
-  }
-
-  int magic = 0;
-  {
-    std::string line;
-    std::getline(is, line);
-    std::istringstream iss(line);
-
-    iss >> magic;
-  }
-
-  int operator_count = 0;
-  int operand_count = 0;
-  {
-    std::string line;
-    std::getline(is, line);
-    std::istringstream iss(line);
-
-    iss >> operator_count >> operand_count;
-  }
-
-  for (int i = 0; i < operator_count; i++) {
-    std::string line;
-    std::getline(is, line);
-    std::istringstream iss(line);
-
-    std::string type;
-    std::string name;
-    int input_count = 0;
-    int output_count = 0;
-
-    iss >> type >> name >> input_count >> output_count;
-
-    Operator* op = new_operator(type, name);
-
-    for (int j = 0; j < input_count; j++) {
-      std::string operand_name;
-      iss >> operand_name;
-
-      Operand* r = get_operand(operand_name);
-      r->consumers.push_back(op);
-      op->inputs.push_back(r);
-    }
-
-    for (int j = 0; j < output_count; j++) {
-      std::string operand_name;
-      iss >> operand_name;
-
-      Operand* r = new_operand(operand_name);
-      r->producer = op;
-      op->outputs.push_back(r);
-    }
-
-    // key=value
-    while (!iss.eof()) {
-      std::string param;
-      iss >> param;
-
-      std::string key;
-      std::string value;
-      std::istringstream pss(param);
-      std::getline(pss, key, '=');
-      std::getline(pss, value);
-
-      if (key[0] == '@') {
-        // attribute
-        // load_attribute(op, key.substr(1), value, szr);
-      } else if (key[0] == '$') {
-        // operand input key
-        // load_input_key(op, key.substr(1), value);
-      } else if (key[0] == '#') {
-        // operand shape
-        load_shape(op, key.substr(1), value);
-      } else {
-        // parameter
-        load_parameter(op, key, value);
-      }
-    }
-  }
-
-  return 0;
-}

+int 
Graph::parse(const std::string& param)
+{
+    std::istringstream is(param);
+    if (!is.good())
+    {
+        fprintf(stderr, "open failed\n");
+        return -1;
+    }
+
+    int magic = 0;
+    {
+        std::string line;
+        std::getline(is, line);
+        std::istringstream iss(line);
+
+        iss >> magic;
+    }
+
+    int operator_count = 0;
+    int operand_count = 0;
+    {
+        std::string line;
+        std::getline(is, line);
+        std::istringstream iss(line);
+
+        iss >> operator_count >> operand_count;
+    }
+
+    for (int i = 0; i < operator_count; i++)
+    {
+        std::string line;
+        std::getline(is, line);
+        std::istringstream iss(line);
+
+        std::string type;
+        std::string name;
+        int input_count = 0;
+        int output_count = 0;
+
+        iss >> type >> name >> input_count >> output_count;
+
+        Operator* op = new_operator(type, name);
+
+        for (int j = 0; j < input_count; j++)
+        {
+            std::string operand_name;
+            iss >> operand_name;
+
+            Operand* r = get_operand(operand_name);
+            r->consumers.push_back(op);
+            op->inputs.push_back(r);
+        }
+
+        for (int j = 0; j < output_count; j++)
+        {
+            std::string operand_name;
+            iss >> operand_name;
+
+            Operand* r = new_operand(operand_name);
+            r->producer = op;
+            op->outputs.push_back(r);
+        }
+
+        // key=value
+        while (!iss.eof())
+        {
+            std::string param;
+            iss >> param;
+
+            std::string key;
+            std::string value;
+            std::istringstream pss(param);
+            std::getline(pss, key, '=');
+            std::getline(pss, value);
+
+            if (key[0] == '@')
+            {
+                // attribute
+                // load_attribute(op, key.substr(1), value, szr);
+                op->attrs[key.substr(1)] = Attribute();
+
+                Attribute& attr = op->attrs[key.substr(1)];
+
+                attr.type = 0;
+                if (value.empty())
+                    continue;
+
+                if (value[0] == '%')
+                {
+                    // @data=%op1.data
+                    attr.data = std::vector<char>(value.begin(), value.end());
+                }
+
+                if (value[0] == '(')
+                {
+                    // @data=(1,%c,?,4)f32
+
+                    // type
+                    std::string typestr = value.substr(value.find_last_of(')') + 1);
+                    attr.type = string_to_type(typestr.c_str());
+
+                    // shape
+                    std::string lc = value.substr(1, value.find_last_of(')') - 1);
+                    std::istringstream lcss(lc);
+
+                    attr.shape.clear();
+                    while (!lcss.eof())
+                    {
+                        std::string elem;
+                        std::getline(lcss, elem, ',');
+
+                        if (elem == "?")
+                        {
+                            attr.shape.push_back(-1);
+                        }
+                        else if (elem[0] == '%')
+                        {
+                            // encode %abc as symbolic tag
+                            attr.shape.push_back(-233);
+                            int index = attr.shape.size() - 1;
+                            std::string key = elem.substr(1);
+                            attr.params[std::string("__shape__") + std::to_string(index)] = key;
+                        }
+                        else
+                        {
+                            int i = std::stoi(elem);
+                            attr.shape.push_back(i);
+                        }
+                    }
+                }
+            }
+            else if (key[0] == '$')
+            {
+                // operand input key
+                load_input_key(op, key.substr(1), value);
+            }
+            else if (key[0] == '#')
+            {
+                // operand shape
+                load_shape(op, key.substr(1), value);
+            }
+            else
+            {
+                // parameter
+                load_parameter(op, key, value);
+            }
+        }
+    }
+
+    return 0;
+}

-void Operand::remove_consumer(const Operator* c) {
-  auto it = std::find(consumers.begin(), consumers.end(), c);
-  consumers.erase(it);
+void Operand::remove_consumer(const Operator* c)
+{
+    auto it = std::find(consumers.begin(), consumers.end(), c);
+    if (it != consumers.end())
+        consumers.erase(it);
 }

-Operator* Graph::new_operator(const std::string& type, const std::string& name) {
-  Operator* op = new Operator;
-  op->type = type;
-  op->name = name;
-  ops.push_back(op);
-  return op;
+Operator* Graph::new_operator(const std::string& type, const std::string& name)
+{
+    Operator* 
op = new Operator;
+    op->type = type;
+    op->name = name;
+    ops.push_back(op);
+    return op;
 }

-Operator* Graph::new_operator_before(const std::string& type, const std::string& name,
-                                     const Operator* cur) {
-  Operator* op = new Operator;
-  op->type = type;
-  op->name = name;
-  ops.insert(std::find(ops.begin(), ops.end(), cur), op);
-  return op;
+Operator* Graph::new_operator_before(const std::string& type, const std::string& name, const Operator* cur)
+{
+    Operator* op = new Operator;
+    op->type = type;
+    op->name = name;
+    ops.insert(std::find(ops.begin(), ops.end(), cur), op);
+    return op;
 }

-Operator* Graph::new_operator_after(const std::string& type, const std::string& name,
-                                    const Operator* cur) {
-  Operator* op = new Operator;
-  op->type = type;
-  op->name = name;
-  ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op);
-  return op;
+Operator* Graph::new_operator_after(const std::string& type, const std::string& name, const Operator* cur)
+{
+    Operator* op = new Operator;
+    op->type = type;
+    op->name = name;
+    ops.insert(std::find(ops.begin(), ops.end(), cur) + 1, op);
+    return op;
 }

-#if BUILD_PNNX
-Operand* Graph::new_operand(const torch::jit::Value* v) {
-  Operand* r = new Operand;
-  r->name = v->debugName();
-
-  auto pt = v->type()->cast<c10::TensorType>();
-  if (pt) {
-    if (pt->scalarType().has_value() && pt->dim().has_value()) {
-      r->type = get_at_tensor_type(pt->scalarType().value());
-      const int ndim = (int)pt->dim().value();
-      r->shape.resize(ndim);
-      for (int i = 0; i < ndim; i++) {
-        if (pt->sizes()[i].has_value())
-          r->shape[i] = (int)pt->sizes()[i].value();
-        else
-          r->shape[i] = -1;
-      }
-    }
-  }
-
-  operands.push_back(r);
-  return r;
-}
-#endif  // BUILD_PNNX

-Operand* Graph::new_operand(const std::string& name) {
-  Operand* r = new Operand;
-  r->name = name;
-  operands.push_back(r);
-  return r;
+Operand* Graph::new_operand(const std::string& name)
+{
+    Operand* r = new Operand;
+    r->name = name;
+    operands.push_back(r);
+    return r;
 }

-Operand* Graph::get_operand(const std::string& name) {
-  for (Operand* r : operands) {
-    if (r->name == name) return r;
-  }
+Operand* Graph::get_operand(const std::string& name)
+{
+    for (Operand* r : operands)
+    {
+        if (r->name == name)
+            return r;
+    }

-  return 0;
+    return 0;
 }

-const Operand* Graph::get_operand(const std::string& name) const {
-  for (const Operand* r : operands) {
-    if (r->name == name) return r;
-  }
+const Operand* Graph::get_operand(const std::string& name) const
+{
+    for (const Operand* r : operands)
+    {
+        if (r->name == name)
+            return r;
+    }

-  return 0;
+    return 0;
 }

-}  // namespace pnnx
+} // namespace pnnx
diff --git a/source/runtime/pnnx/store_zip.cpp b/source/runtime/pnnx/store_zip.cpp
index 1eedfcb7..8e354f7b 100644
--- a/source/runtime/pnnx/store_zip.cpp
+++ b/source/runtime/pnnx/store_zip.cpp
@@ -3,7 +3,7 @@
 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
-// in compliance with the License. You may obtain a copy of the License slice
+// in compliance with the License. You may obtain a copy of the License at
 //
 //     https://opensource.org/licenses/BSD-3-Clause
 //
@@ -13,8 +13,9 @@
 // specific language governing permissions and limitations under the License.
#include "runtime/pnnx/store_zip.hpp" -#include + #include +#include #include #include #include @@ -29,324 +30,527 @@ namespace pnnx { #endif PACK(struct local_file_header { - uint16_t version; - uint16_t flag; - uint16_t compression; - uint16_t last_modify_time; - uint16_t last_modify_date; - uint32_t crc32; - uint32_t compressed_size; - uint32_t uncompressed_size; - uint16_t file_name_length; - uint16_t extra_field_length; + uint16_t version; + uint16_t flag; + uint16_t compression; + uint16_t last_modify_time; + uint16_t last_modify_date; + uint32_t crc32; + uint32_t compressed_size; + uint32_t uncompressed_size; + uint16_t file_name_length; + uint16_t extra_field_length; +}); + +PACK(struct zip64_extended_extra_field { + uint64_t uncompressed_size; + uint64_t compressed_size; + uint64_t lfh_offset; + uint32_t disk_number; }); PACK(struct central_directory_file_header { - uint16_t version_made; - uint16_t version; - uint16_t flag; - uint16_t compression; - uint16_t last_modify_time; - uint16_t last_modify_date; - uint32_t crc32; - uint32_t compressed_size; - uint32_t uncompressed_size; - uint16_t file_name_length; - uint16_t extra_field_length; - uint16_t file_comment_length; - uint16_t start_disk; - uint16_t internal_file_attrs; - uint32_t external_file_attrs; - uint32_t lfh_offset; + uint16_t version_made; + uint16_t version; + uint16_t flag; + uint16_t compression; + uint16_t last_modify_time; + uint16_t last_modify_date; + uint32_t crc32; + uint32_t compressed_size; + uint32_t uncompressed_size; + uint16_t file_name_length; + uint16_t extra_field_length; + uint16_t file_comment_length; + uint16_t start_disk; + uint16_t internal_file_attrs; + uint32_t external_file_attrs; + uint32_t lfh_offset; +}); + +PACK(struct zip64_end_of_central_directory_record { + uint64_t size_of_eocd64_m12; + uint16_t version_made_by; + uint16_t version_min_required; + uint32_t disk_number; + uint32_t start_disk; + uint64_t cd_records; + uint64_t total_cd_records; + uint64_t cd_size; + uint64_t cd_offset; +}); + +PACK(struct zip64_end_of_central_directory_locator { + uint32_t eocdr64_disk_number; + uint64_t eocdr64_offset; + uint32_t disk_count; }); PACK(struct end_of_central_directory_record { - uint16_t disk_number; - uint16_t start_disk; - uint16_t cd_records; - uint16_t total_cd_records; - uint32_t cd_size; - uint32_t cd_offset; - uint16_t comment_length; + uint16_t disk_number; + uint16_t start_disk; + uint16_t cd_records; + uint16_t total_cd_records; + uint32_t cd_size; + uint32_t cd_offset; + uint16_t comment_length; }); static uint32_t CRC32_TABLE[256]; -static void CRC32_TABLE_INIT() { - for (int i = 0; i < 256; i++) { - uint32_t c = i; - for (int j = 0; j < 8; j++) { - if (c & 1) - c = (c >> 1) ^ 0xedb88320; - else - c >>= 1; +static void CRC32_TABLE_INIT() +{ + for (int i = 0; i < 256; i++) + { + uint32_t c = i; + for (int j = 0; j < 8; j++) + { + if (c & 1) + c = (c >> 1) ^ 0xedb88320; + else + c >>= 1; + } + CRC32_TABLE[i] = c; } - CRC32_TABLE[i] = c; - } } -static uint32_t CRC32(uint32_t x, unsigned char ch) { - return (x >> 8) ^ CRC32_TABLE[(x ^ ch) & 0xff]; +static uint32_t CRC32(uint32_t x, unsigned char ch) +{ + return (x >> 8) ^ CRC32_TABLE[(x ^ ch) & 0xff]; } -static uint32_t CRC32_buffer(const unsigned char* data, int len) { - uint32_t x = 0xffffffff; +static uint32_t CRC32_buffer(const unsigned char* data, uint64_t len) +{ + uint32_t x = 0xffffffff; - for (int i = 0; i < len; i++) x = CRC32(x, data[i]); + for (uint64_t i = 0; i < len; i++) + x = CRC32(x, data[i]); - return x ^ 
0xffffffff; + return x ^ 0xffffffff; } -StoreZipReader::StoreZipReader() { fp = 0; } - -StoreZipReader::~StoreZipReader() { close(); } - -int StoreZipReader::open(const std::string& path) { - close(); - - fp = fopen(path.c_str(), "rb"); - if (!fp) { - fprintf(stderr, "open failed\n"); - return -1; - } +StoreZipReader::StoreZipReader() +{ + fp = 0; +} - while (!feof(fp)) { - // peek signature - uint32_t signature; - int nread = fread((char*)&signature, sizeof(signature), 1, fp); - if (nread != 1) break; +StoreZipReader::~StoreZipReader() +{ + close(); +} - if (signature == 0x04034b50) { - local_file_header lfh; - fread((char*)&lfh, sizeof(lfh), 1, fp); +int StoreZipReader::open(const std::string& path) +{ + close(); - if (lfh.flag & 0x08) { - fprintf(stderr, "zip file contains data descriptor, this is not supported yet\n"); + fp = fopen(path.c_str(), "rb"); + if (!fp) + { + fprintf(stderr, "open failed\n"); return -1; - } + } - if (lfh.compression != 0 || lfh.compressed_size != lfh.uncompressed_size) { - fprintf(stderr, "not stored zip file %d %d\n", lfh.compressed_size, lfh.uncompressed_size); - return -1; - } + while (!feof(fp)) + { + // peek signature + uint32_t signature; + int nread = fread((char*)&signature, sizeof(signature), 1, fp); + if (nread != 1) + break; + + // fprintf(stderr, "signature = %x\n", signature); + + if (signature == 0x04034b50) + { + local_file_header lfh; + fread((char*)&lfh, sizeof(lfh), 1, fp); + + if (lfh.flag & 0x08) + { + fprintf(stderr, "zip file contains data descriptor, this is not supported yet\n"); + return -1; + } + + if (lfh.compression != 0 || lfh.compressed_size != lfh.uncompressed_size) + { + fprintf(stderr, "not stored zip file %d %d\n", lfh.compressed_size, lfh.uncompressed_size); + return -1; + } + + // file name + std::string name; + name.resize(lfh.file_name_length); + fread((char*)name.data(), name.size(), 1, fp); + + uint64_t compressed_size = lfh.compressed_size; + uint64_t uncompressed_size = lfh.uncompressed_size; + if (compressed_size == 0xffffffff && uncompressed_size == 0xffffffff) + { + uint16_t extra_offset = 0; + while (extra_offset < lfh.extra_field_length) + { + uint16_t extra_id; + uint16_t extra_size; + fread((char*)&extra_id, sizeof(extra_id), 1, fp); + fread((char*)&extra_size, sizeof(extra_size), 1, fp); + if (extra_id != 0x0001) + { + // skip this extra field block + fseek(fp, extra_size - 4, SEEK_CUR); + extra_offset += extra_size; + continue; + } + + // zip64 extra field + zip64_extended_extra_field zip64_eef; + fread((char*)&zip64_eef, sizeof(zip64_eef), 1, fp); + + compressed_size = zip64_eef.compressed_size; + uncompressed_size = zip64_eef.uncompressed_size; + + // skip remaining extra field blocks + fseek(fp, lfh.extra_field_length - extra_offset - 4 - sizeof(zip64_eef), SEEK_CUR); + break; + } + } + else + { + // skip extra field + fseek(fp, lfh.extra_field_length, SEEK_CUR); + } + + StoreZipMeta fm; + fm.offset = ftell(fp); + fm.size = compressed_size; + + filemetas[name] = fm; + + // fprintf(stderr, "%s = %d %d\n", name.c_str(), fm.offset, fm.size); + + fseek(fp, compressed_size, SEEK_CUR); + } + else if (signature == 0x02014b50) + { + central_directory_file_header cdfh; + fread((char*)&cdfh, sizeof(cdfh), 1, fp); + + // skip file name + fseek(fp, cdfh.file_name_length, SEEK_CUR); + + // skip extra field + fseek(fp, cdfh.extra_field_length, SEEK_CUR); + + // skip file comment + fseek(fp, cdfh.file_comment_length, SEEK_CUR); + } + else if (signature == 0x06054b50) + { + end_of_central_directory_record eocdr; + 
fread((char*)&eocdr, sizeof(eocdr), 1, fp);
+
+            // skip comment
+            fseek(fp, eocdr.comment_length, SEEK_CUR);
+        }
+        else if (signature == 0x06064b50)
+        {
+            zip64_end_of_central_directory_record eocdr64;
+            fread((char*)&eocdr64, sizeof(eocdr64), 1, fp);
+
+            // skip comment
+            fseek(fp, eocdr64.size_of_eocd64_m12 - 44, SEEK_CUR);
+        }
+        else if (signature == 0x07064b50)
+        {
+            zip64_end_of_central_directory_locator eocdl64;
+            fread((char*)&eocdl64, sizeof(eocdl64), 1, fp);
+        }
+        else
+        {
+            fprintf(stderr, "unsupported signature %x\n", signature);
+            return -1;
+        }
+    }

-      // file name
-      std::string name;
-      name.resize(lfh.file_name_length);
-      fread((char*)name.data(), name.size(), 1, fp);
+    return 0;
+}

-      // skip extra field
-      fseek(fp, lfh.extra_field_length, SEEK_CUR);
+std::vector<std::string> StoreZipReader::get_names() const
+{
+    std::vector<std::string> names;
+    for (std::map<std::string, StoreZipMeta>::const_iterator it = filemetas.begin(); it != filemetas.end(); ++it)
+    {
+        names.push_back(it->first);
+    }
+
+    return names;
+}

-      StoreZipMeta fm;
-      fm.offset = ftell(fp);
-      fm.size = lfh.compressed_size;
+uint64_t StoreZipReader::get_file_size(const std::string& name) const
+{
+    if (filemetas.find(name) == filemetas.end())
+    {
+        fprintf(stderr, "no such file %s\n", name.c_str());
+        return 0;
+    }

-      filemetas[name] = fm;
+    return filemetas.at(name).size;
+}

-      // fprintf(stderr, "%s = %d %d\n", name.c_str(), fm.offset, fm.size);
+int StoreZipReader::read_file(const std::string& name, char* data)
+{
+    if (filemetas.find(name) == filemetas.end())
+    {
+        fprintf(stderr, "no such file %s\n", name.c_str());
+        return -1;
+    }

-      fseek(fp, lfh.compressed_size, SEEK_CUR);
-    } else if (signature == 0x02014b50) {
-      central_directory_file_header cdfh;
-      fread((char*)&cdfh, sizeof(cdfh), 1, fp);
+    uint64_t offset = filemetas[name].offset;
+    uint64_t size = filemetas[name].size;

-      // skip file name
-      fseek(fp, cdfh.file_name_length, SEEK_CUR);
+    fseek(fp, offset, SEEK_SET);
+    fread(data, size, 1, fp);

-      // skip extra field
-      fseek(fp, cdfh.extra_field_length, SEEK_CUR);
+    return 0;
+}

-      // skip file comment
-      fseek(fp, cdfh.file_comment_length, SEEK_CUR);
-    } else if (signature == 0x06054b50) {
-      end_of_central_directory_record eocdr;
-      fread((char*)&eocdr, sizeof(eocdr), 1, fp);
+int StoreZipReader::close()
+{
+    if (!fp)
+        return 0;

-      // skip comment
-      fseek(fp, eocdr.comment_length, SEEK_CUR);
-    } else {
-      fprintf(stderr, "unsupported signature %x\n", signature);
-      return -1;
-    }
-  }

-  return 0;
-}
+    fclose(fp);
+    fp = 0;

-size_t StoreZipReader::get_file_size(const std::string& name) {
-  if (filemetas.find(name) == filemetas.end()) {
-    fprintf(stderr, "no such file %s\n", name.c_str());
-    return 0;
-  }
+    return 0;
+}

-  return filemetas[name].size;
-}
+StoreZipWriter::StoreZipWriter()
+{
+    fp = 0;

-int StoreZipReader::read_file(const std::string& name, char* data) {
-  if (filemetas.find(name) == filemetas.end()) {
-    fprintf(stderr, "no such file %s\n", name.c_str());
-    return -1;
-  }
+    CRC32_TABLE_INIT();
+}

-  size_t offset = filemetas[name].offset;
-  size_t size = filemetas[name].size;
+StoreZipWriter::~StoreZipWriter()
+{
+    close();
+}

-  fseek(fp, offset, SEEK_SET);
-  fread(data, size, 1, fp);
+int StoreZipWriter::open(const std::string& path)
+{
+    close();

-  return 0;
-}
+    fp = fopen(path.c_str(), "wb");
+    if (!fp)
+    {
+        fprintf(stderr, "open failed\n");
+        return -1;
+    }

-int StoreZipReader::close() {
-  if (!fp) return 0;
+    return 0;
+}

-  fclose(fp);
-  
fp = 0; + uint32_t signature = 0x04034b50; + fwrite((char*)&signature, sizeof(signature), 1, fp); - return 0; -} + uint32_t crc32 = CRC32_buffer((const unsigned char*)data, size); -StoreZipWriter::StoreZipWriter() { - fp = 0; + local_file_header lfh; + lfh.version = 0; + lfh.flag = 0; + lfh.compression = 0; + lfh.last_modify_time = 0; + lfh.last_modify_date = 0; + lfh.crc32 = crc32; + lfh.compressed_size = 0xffffffff; + lfh.uncompressed_size = 0xffffffff; + lfh.file_name_length = name.size(); - CRC32_TABLE_INIT(); -} + // zip64 extra field + zip64_extended_extra_field zip64_eef; + zip64_eef.uncompressed_size = size; + zip64_eef.compressed_size = size; + zip64_eef.lfh_offset = 0; + zip64_eef.disk_number = 0; -StoreZipWriter::~StoreZipWriter() { close(); } + uint16_t extra_id = 0x0001; + uint16_t extra_size = sizeof(zip64_eef); -int StoreZipWriter::open(const std::string& path) { - close(); + lfh.extra_field_length = sizeof(extra_id) + sizeof(extra_size) + sizeof(zip64_eef); - fp = fopen(path.c_str(), "wb"); - if (!fp) { - fprintf(stderr, "open failed\n"); - return -1; - } + fwrite((char*)&lfh, sizeof(lfh), 1, fp); - return 0; -} + fwrite((char*)name.c_str(), name.size(), 1, fp); -int StoreZipWriter::write_file(const std::string& name, const char* data, size_t size) { - int offset = ftell(fp); + fwrite((char*)&extra_id, sizeof(extra_id), 1, fp); + fwrite((char*)&extra_size, sizeof(extra_size), 1, fp); + fwrite((char*)&zip64_eef, sizeof(zip64_eef), 1, fp); - uint32_t signature = 0x04034b50; - fwrite((char*)&signature, sizeof(signature), 1, fp); + fwrite(data, size, 1, fp); - uint32_t crc32 = CRC32_buffer((const unsigned char*)data, size); + StoreZipMeta szm; + szm.name = name; + szm.lfh_offset = offset; + szm.crc32 = crc32; + szm.size = size; - local_file_header lfh; - lfh.version = 0; - lfh.flag = 0; - lfh.compression = 0; - lfh.last_modify_time = 0; - lfh.last_modify_date = 0; - lfh.crc32 = crc32; - lfh.compressed_size = size; - lfh.uncompressed_size = size; - lfh.file_name_length = name.size(); - lfh.extra_field_length = 0; + filemetas.push_back(szm); - fwrite((char*)&lfh, sizeof(lfh), 1, fp); + return 0; +} - fwrite((char*)name.c_str(), name.size(), 1, fp); +int StoreZipWriter::close() +{ + if (!fp) + return 0; + + long offset = ftell(fp); + + for (const StoreZipMeta& szm : filemetas) + { + uint32_t signature = 0x02014b50; + fwrite((char*)&signature, sizeof(signature), 1, fp); + + central_directory_file_header cdfh; + cdfh.version_made = 0; + cdfh.version = 0; + cdfh.flag = 0; + cdfh.compression = 0; + cdfh.last_modify_time = 0; + cdfh.last_modify_date = 0; + cdfh.crc32 = szm.crc32; + cdfh.compressed_size = 0xffffffff; + cdfh.uncompressed_size = 0xffffffff; + cdfh.file_name_length = szm.name.size(); + cdfh.file_comment_length = 0; + cdfh.start_disk = 0xffff; + cdfh.internal_file_attrs = 0; + cdfh.external_file_attrs = 0; + cdfh.lfh_offset = 0xffffffff; + + // zip64 extra field + zip64_extended_extra_field zip64_eef; + zip64_eef.uncompressed_size = szm.size; + zip64_eef.compressed_size = szm.size; + zip64_eef.lfh_offset = szm.lfh_offset; + zip64_eef.disk_number = 0; + + uint16_t extra_id = 0x0001; + uint16_t extra_size = sizeof(zip64_eef); + + cdfh.extra_field_length = sizeof(extra_id) + sizeof(extra_size) + sizeof(zip64_eef); + + fwrite((char*)&cdfh, sizeof(cdfh), 1, fp); + + fwrite((char*)szm.name.c_str(), szm.name.size(), 1, fp); + + fwrite((char*)&extra_id, sizeof(extra_id), 1, fp); + fwrite((char*)&extra_size, sizeof(extra_size), 1, fp); + fwrite((char*)&zip64_eef, sizeof(zip64_eef), 
1, fp);
+    }

-  fwrite(data, size, 1, fp);
+    long offset2 = ftell(fp);

-  StoreZipMeta szm;
-  szm.name = name;
-  szm.lfh_offset = offset;
-  szm.crc32 = crc32;
-  szm.size = size;
+    {
+        uint32_t signature = 0x06064b50;
+        fwrite((char*)&signature, sizeof(signature), 1, fp);

-  filemetas.push_back(szm);
+        zip64_end_of_central_directory_record eocdr64;
+        eocdr64.size_of_eocd64_m12 = sizeof(eocdr64) - 8;
+        eocdr64.version_made_by = 0;
+        eocdr64.version_min_required = 0;
+        eocdr64.disk_number = 0;
+        eocdr64.start_disk = 0;
+        eocdr64.cd_records = filemetas.size();
+        eocdr64.total_cd_records = filemetas.size();
+        eocdr64.cd_size = offset2 - offset;
+        eocdr64.cd_offset = offset;

-  return 0;
-}
+        fwrite((char*)&eocdr64, sizeof(eocdr64), 1, fp);
+    }

-int StoreZipWriter::close() {
-  if (!fp) return 0;
+    {
+        uint32_t signature = 0x07064b50;
+        fwrite((char*)&signature, sizeof(signature), 1, fp);

-  int offset = ftell(fp);
+        zip64_end_of_central_directory_locator eocdl64;
+        eocdl64.eocdr64_disk_number = 0;
+        eocdl64.eocdr64_offset = offset2;
+        eocdl64.disk_count = 1;

-  for (const StoreZipMeta& szm : filemetas) {
-    uint32_t signature = 0x02014b50;
-    fwrite((char*)&signature, sizeof(signature), 1, fp);
+        fwrite((char*)&eocdl64, sizeof(eocdl64), 1, fp);
+    }

-    central_directory_file_header cdfh;
-    cdfh.version_made = 0;
-    cdfh.version = 0;
-    cdfh.flag = 0;
-    cdfh.compression = 0;
-    cdfh.last_modify_time = 0;
-    cdfh.last_modify_date = 0;
-    cdfh.crc32 = szm.crc32;
-    cdfh.compressed_size = szm.size;
-    cdfh.uncompressed_size = szm.size;
-    cdfh.file_name_length = szm.name.size();
-    cdfh.extra_field_length = 0;
-    cdfh.file_comment_length = 0;
-    cdfh.start_disk = 0;
-    cdfh.internal_file_attrs = 0;
-    cdfh.external_file_attrs = 0;
-    cdfh.lfh_offset = szm.lfh_offset;
-
-    fwrite((char*)&cdfh, sizeof(cdfh), 1, fp);
-
-    fwrite((char*)szm.name.c_str(), szm.name.size(), 1, fp);
-  }
-
-  int offset2 = ftell(fp);
-
-  {
-    uint32_t signature = 0x06054b50;
-    fwrite((char*)&signature, sizeof(signature), 1, fp);
+    {
+        uint32_t signature = 0x06054b50;
+        fwrite((char*)&signature, sizeof(signature), 1, fp);

-    end_of_central_directory_record eocdr;
-    eocdr.disk_number = 0;
-    eocdr.start_disk = 0;
-    eocdr.cd_records = filemetas.size();
-    eocdr.total_cd_records = filemetas.size();
-    eocdr.cd_size = offset2 - offset;
-    eocdr.cd_offset = offset;
-    eocdr.comment_length = 0;
+        end_of_central_directory_record eocdr;
+        eocdr.disk_number = 0xffff;
+        eocdr.start_disk = 0xffff;
+        eocdr.cd_records = 0xffff;
+        eocdr.total_cd_records = 0xffff;
+        eocdr.cd_size = 0xffffffff;
+        eocdr.cd_offset = 0xffffffff;
+        eocdr.comment_length = 0;

-    fwrite((char*)&eocdr, sizeof(eocdr), 1, fp);
-  }
+        fwrite((char*)&eocdr, sizeof(eocdr), 1, fp);
+    }

-  fclose(fp);
-  fp = 0;
+    fclose(fp);
+    fp = 0;

-  return 0;
+    return 0;
 }

-}  // namespace pnnx
+} // namespace pnnx

 #if 0
 int main()
 {
-  StoreZipReader sz;
+    using namespace pnnx;

-  sz.open("test.zip");
+    {
+        uint64_t len = 1*1024*1024*1024;
+        // uint64_t len = 1*1024*1024;
+        char* data1g = new char[len];

-  std::vector<char> data1;
-  sz.read_file("pnnx2.py", data1);
+        StoreZipWriter szw;

-  std::vector<char> data2;
-  sz.read_file("pnnx2.param", data2);
+        szw.open("szw.zip");

-  sz.close();
+        szw.write_file("a.py", data1g, len);
+        szw.write_file("b.param", data1g, 44);
+        szw.write_file("c.bin", data1g, len);
+        szw.write_file("d.txt", data1g, len);
+        szw.write_file("e.jpg", data1g, len);
+        szw.write_file("f.png", data1g, len);
+        szw.close();

-  StoreZipWriter szw;
+        delete[] data1g;
+    }
+
+    {
+        StoreZipReader sz;

-  szw.open("szw.zip");
+        
sz.open("szw.zip"); - szw.write_file("a.py", data1); - szw.write_file("zzzz.param", data2); + std::vector names = sz.get_names(); - szw.close(); + for (size_t i = 0; i < names.size(); i++) + { + uint64_t size = sz.get_file_size(names[i]); + fprintf(stderr, "%s %lu\n", names[i].c_str(), size); + } + + sz.close(); + } return 0; }