diff --git a/compiler/circle-mpqsolver/src/MPQSolver.cpp b/compiler/circle-mpqsolver/src/MPQSolver.cpp index 10cfbb65fb0..eea82efef0f 100644 --- a/compiler/circle-mpqsolver/src/MPQSolver.cpp +++ b/compiler/circle-mpqsolver/src/MPQSolver.cpp @@ -16,6 +16,9 @@ #include "MPQSolver.h" +#include +#include + using namespace mpqsolver; MPQSolver::MPQSolver(const std::string &input_data_path, float qerror_ratio, @@ -23,9 +26,23 @@ MPQSolver::MPQSolver(const std::string &input_data_path, float qerror_ratio, : _input_data_path(input_data_path), _qerror_ratio(qerror_ratio), _input_quantization(input_quantization), _output_quantization(output_quantization) { + _quantizer = std::make_unique(_input_quantization, _output_quantization); } void MPQSolver::set_save_intermediate(const std::string &save_path) { _hooks = std::make_unique(save_path); } + +std::unique_ptr MPQSolver::read_module(const std::string &path) +{ + luci::ImporterEx importerex; + auto module = importerex.importVerifyModule(path); + if (module.get() == nullptr) + { + std::cerr << "ERROR: Failed to load " << path << std::endl; + return nullptr; + } + + return module; +} diff --git a/compiler/circle-mpqsolver/src/MPQSolver.h b/compiler/circle-mpqsolver/src/MPQSolver.h index 6c5d25dad78..6718be2ea18 100644 --- a/compiler/circle-mpqsolver/src/MPQSolver.h +++ b/compiler/circle-mpqsolver/src/MPQSolver.h @@ -17,8 +17,11 @@ #ifndef __MPQSOLVER_MPQSOLEVR_SOLVER_H__ #define __MPQSOLVER_MPQSOLEVR_SOLVER_H__ +#include "core/Quantizer.h" #include +#include + #include #include @@ -47,10 +50,14 @@ class MPQSolver */ void set_save_intermediate(const std::string &save_path); +protected: + std::unique_ptr read_module(const std::string &path); + protected: std::string _input_data_path; std::string _input_quantization; std::string _output_quantization; + std::unique_ptr _quantizer; float _qerror_ratio = 0.f; // quantization error ratio std::unique_ptr _hooks; }; diff --git a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp index 976dac55028..6272947ab2f 100644 --- a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp +++ b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.cpp @@ -72,19 +72,6 @@ bool front_has_higher_error(const NodeDepthType &nodes_depth, const std::string return error_at_input > error_at_output; } -std::unique_ptr read_module(const std::string &path) -{ - luci::ImporterEx importerex; - auto module = importerex.importVerifyModule(path); - if (module.get() == nullptr) - { - std::cerr << "ERROR: Failed to load " << path << std::endl; - return nullptr; - } - - return module; -} - } // namespace BisectionSolver::BisectionSolver(const std::string &input_data_path, float qerror_ratio, @@ -92,7 +79,6 @@ BisectionSolver::BisectionSolver(const std::string &input_data_path, float qerro const std::string &output_quantization) : MPQSolver(input_data_path, qerror_ratio, input_quantization, output_quantization) { - _quantizer = std::make_unique(_input_quantization, _output_quantization); } float BisectionSolver::evaluate(const core::DatasetEvaluator &evaluator, diff --git a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h index 83851c0c853..1b73be98eac 100644 --- a/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h +++ b/compiler/circle-mpqsolver/src/bisection/BisectionSolver.h @@ -17,7 +17,6 @@ #ifndef __MPQSOLVER_BISECTION_SOLVER_H__ #define __MPQSOLVER_BISECTION_SOLVER_H__ -#include #include #include @@ -78,7 +77,6 @@ class BisectionSolver final : public MPQSolver private: float _qerror = 0.f; // quantization error Algorithm _algorithm = Algorithm::ForceQ16Front; - std::unique_ptr _quantizer; std::string _visq_data_path; };