Skip to content

Commit

Permalink
[onert] Reset compile result after reload (#13314)
Browse files Browse the repository at this point in the history
This commit updates API implementation to reset compile result (`_execution`, `_compiler_artifact`) after model reload.
For model load, it introduce new private method.

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <[email protected]>
  • Loading branch information
hseok-oh authored Jun 28, 2024
1 parent 768185d commit db62e19
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 26 deletions.
48 changes: 22 additions & 26 deletions runtime/onert/api/nnfw/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,20 +326,13 @@ NNFW_STATUS nnfw_session::load_model_from_modelfile(const char *model_file_path)
std::string model_type = filename.substr(dotidx + 1); // + 1 to exclude dot
try
{
auto model = loadModel(filename, model_type);
if (model == nullptr)
return NNFW_STATUS_ERROR;
_model_path = std::string(model_file_path);
_nnpkg = std::make_shared<onert::ir::NNPkg>(std::move(model));
_train_info = loadTrainingInfo(_nnpkg->primary_model());
_state = State::MODEL_LOADED;
return loadModelFile(filename, model_type);
}
catch (const std::exception &e)
{
std::cerr << "Error during model loading : " << e.what() << std::endl;
return NNFW_STATUS_ERROR;
}
return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::load_model_from_nnpackage(const char *package_dir)
Expand Down Expand Up @@ -979,6 +972,23 @@ uint32_t nnfw_session::getOutputSize()
return _compiler_artifact->_executors->outputSize();
}

NNFW_STATUS nnfw_session::loadModelFile(const std::string &model_file_path,
const std::string &model_type)
{
auto model = loadModel(model_file_path, model_type);
if (model == nullptr)
return NNFW_STATUS_ERROR;

_nnpkg = std::make_shared<onert::ir::NNPkg>(std::move(model));
_model_path = model_file_path;
_compiler_artifact.reset();
_execution.reset();
_train_info = loadTrainingInfo(_nnpkg->primary_model());
_state = State::MODEL_LOADED;

return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::get_config(const char *key, char *value, size_t value_size)
{
if (!isStateModelLoaded())
Expand Down Expand Up @@ -1790,20 +1800,13 @@ NNFW_STATUS nnfw_session::quantize()

// Replace model
// TODO Support buffer replace, not file reload
auto model = loadModel(_quant_manager->exportModelPath(), "circle");
if (model == nullptr)
return NNFW_STATUS_ERROR;
_nnpkg->replaceModel(std::move(model));
_state = State::MODEL_LOADED;
_model_path = _quant_manager->exportModelPath();
return loadModelFile(_quant_manager->exportModelPath(), "circle");
}
catch (const std::exception &e)
{
std::cerr << "Error during nnfw_session::quantize : " << e.what() << std::endl;
return NNFW_STATUS_ERROR;
}

return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::set_codegen_model_path(const char *path)
Expand Down Expand Up @@ -1897,23 +1900,16 @@ NNFW_STATUS nnfw_session::codegen(const char *target, NNFW_CODEGEN_PREF pref)
}

std::string model_type = export_model_path.substr(dotidx + 1); // + 1 to exclude dot
auto model = loadModel(export_model_path, model_type);
if (model == nullptr)
return NNFW_STATUS_ERROR;

_nnpkg->replaceModel(std::move(model));
_state = State::MODEL_LOADED;
_model_path = export_model_path;
_compiler_artifact.reset();
_execution.reset();
// Replace model
// TODO Support buffer replace, not file reload
return loadModelFile(export_model_path, model_type);
}
catch (const std::exception &e)
{
std::cerr << "Error during nnfw_session::compile : " << e.what() << std::endl;
return NNFW_STATUS_ERROR;
}

return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::set_prepare_config(const NNFW_PREPARE_CONFIG key, const char *)
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/api/nnfw/src/nnfw_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ struct nnfw_session
const onert::ir::IGraph *primary_subgraph();
uint32_t getInputSize();
uint32_t getOutputSize();
NNFW_STATUS loadModelFile(const std::string &model_file_path, const std::string &model_type);

bool isStateInitialized();
bool isStateModelLoaded();
Expand Down

0 comments on commit db62e19

Please sign in to comment.