From e35823fcc481d7710702685bb105f6279de68d7d Mon Sep 17 00:00:00 2001 From: Balyshev Artem <43214667+BalyshevArtem@users.noreply.github.com> Date: Wed, 20 Sep 2023 12:00:22 +0300 Subject: [PATCH] [onert-micro] Introduce TrainingDriver (#11558) This commit introduces TrainingDriver module for onert-micro. ONE-DCO-1.0-Signed-off-by: Artem Balyshev Co-authored-by: Artem Balyshev --- onert-micro/CMakeLists.txt | 9 + onert-micro/eval-driver/CMakeLists.txt | 20 ++ onert-micro/eval-driver/TrainingDriver.cpp | 241 ++++++++++++++++++++ onert-micro/luci-interpreter/CMakeLists.txt | 4 + 4 files changed, 274 insertions(+) create mode 100644 onert-micro/eval-driver/TrainingDriver.cpp diff --git a/onert-micro/CMakeLists.txt b/onert-micro/CMakeLists.txt index 416281d0f02..dca93d480a9 100644 --- a/onert-micro/CMakeLists.txt +++ b/onert-micro/CMakeLists.txt @@ -121,6 +121,12 @@ if (ENABLE_ONERT_MICRO_TEST) list(APPEND CMAKE_ARM_OPTIONS "-DENABLE_TEST=ON") endif () +if (ENABLE_ONERT_MICRO_TRAINING) + message(STATUS "ENABLE TRAINING PART") + add_definitions(-DENABLE_TRAINING) + list(APPEND CMAKE_ARM_OPTIONS "-DENABLE_TRAINING=ON") +endif () + if (DIS_QUANT) message(STATUS "ONERT-MICRO will not use part for QUANTIZED models") add_definitions(-DDIS_QUANT) @@ -195,6 +201,9 @@ add_custom_target(luci_interpreter_micro_arm DEPENDS "${MICRO_ARM_BINARY}") add_subdirectory(eval-driver) +# Should be after add_subdirectory +unset(ENABLE_ONERT_MICRO_TRAINING CACHE) + if (NOT DEFINED BUILD_TEST) return() endif () diff --git a/onert-micro/eval-driver/CMakeLists.txt b/onert-micro/eval-driver/CMakeLists.txt index 2a1b73ad65c..8b29d968bc6 100644 --- a/onert-micro/eval-driver/CMakeLists.txt +++ b/onert-micro/eval-driver/CMakeLists.txt @@ -11,3 +11,23 @@ target_include_directories(onert_micro_eval_driver PUBLIC "${CMAKE_CURRENT_SOURC target_link_libraries(onert_micro_eval_driver PUBLIC luci_interpreter_micro) install(TARGETS onert_micro_eval_driver DESTINATION bin) + +message(STATUS "DONE eval driver") + +if(NOT ENABLE_ONERT_MICRO_TRAINING) + return() +endif(NOT ENABLE_ONERT_MICRO_TRAINING) + +set(SRCS_EVAL_TRAINING_TESTER TrainingDriver.cpp) + +add_executable(onert_micro_training_eval_driver ${SRCS_EVAL_TRAINING_TESTER}) + +# This variable is needed to separate standalone interpreter libraries from the libraries used in driver +set(READER_SUFFIX "_driver") + +target_include_directories(onert_micro_training_eval_driver PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/luci-interpreter/include") +target_link_libraries(onert_micro_training_eval_driver PUBLIC luci_interpreter_micro) + +install(TARGETS onert_micro_training_eval_driver DESTINATION bin) + +message(STATUS "DONE training eval driver") diff --git a/onert-micro/eval-driver/TrainingDriver.cpp b/onert-micro/eval-driver/TrainingDriver.cpp new file mode 100644 index 00000000000..2625ec8924b --- /dev/null +++ b/onert-micro/eval-driver/TrainingDriver.cpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace +{ + +using DataBuffer = std::vector; + +void readDataFromFile(const std::string &filename, char *data, size_t data_size) +{ + std::ifstream fs(filename, std::ifstream::binary); + if (fs.fail()) + throw std::runtime_error("Cannot open file \"" + filename + "\".\n"); + if (fs.read(data, data_size).fail()) + throw std::runtime_error("Failed to read data from file \"" + filename + "\".\n"); +} + +void writeDataToFile(const std::string &filename, const char *data, size_t data_size) +{ + std::ofstream fs(filename, std::ofstream::binary); + if (fs.fail()) + throw std::runtime_error("Cannot open file \"" + filename + "\".\n"); + if (fs.write(data, data_size).fail()) + { + throw std::runtime_error("Failed to write data to file \"" + filename + "\".\n"); + } +} + +} // namespace + +/* + * @brief EvalDriver main + * + * Driver for testing luci-inerpreter + * + */ +int entry(int argc, char **argv) +{ + if (argc != 8) + { + std::cerr + << "Usage: " << argv[0] + << " " + " num_of_train_smpl " + "num_of_test_smpl\n"; + return EXIT_FAILURE; + } + + const char *filename = argv[1]; + const char *input_train_data_path = argv[2]; + const char *input_label_train_data_path = argv[3]; + const char *input_test_data_path = argv[4]; + const char *input_label_test_data_path = argv[5]; + const int32_t num_train_data_samples = atoi(argv[6]); + const int32_t num_test_data_samples = atoi(argv[7]); + + std::ifstream file(filename, std::ios::binary | std::ios::in); + if (!file.good()) + { + std::string errmsg = "Failed to open file"; + throw std::runtime_error(errmsg.c_str()); + } + + file.seekg(0, std::ios::end); + auto fileSize = file.tellg(); + file.seekg(0, std::ios::beg); + + // reserve capacity + DataBuffer model_data(fileSize); + + // read the data + file.read(model_data.data(), fileSize); + if (file.fail()) + { + std::string errmsg = "Failed to read file"; + throw std::runtime_error(errmsg.c_str()); + } + + // Create interpreter. + luci_interpreter::Interpreter interpreter(model_data.data(), true); + + luci_interpreter::training::TrainingSettings settings; + settings.learning_rate = 0.0001; + settings.number_of_epochs = 100; + settings.batch_size = 1; + + const auto input_size = interpreter.getInputDataSizeByIndex(0); + const auto output_size = interpreter.getOutputDataSizeByIndex(0); + + char *train_data = new char[input_size * num_train_data_samples]; + char *label_train_data = new char[output_size * num_train_data_samples]; + + char *test_data = new char[input_size * num_test_data_samples]; + char *label_test_data = new char[output_size * num_test_data_samples]; + + readDataFromFile(input_train_data_path, train_data, input_size * num_train_data_samples); + readDataFromFile(input_test_data_path, test_data, input_size * num_test_data_samples); + + readDataFromFile(input_label_train_data_path, label_train_data, + output_size * num_train_data_samples); + readDataFromFile(input_label_test_data_path, label_test_data, + output_size * num_test_data_samples); + + luci_interpreter::training::TrainingOnertMicro onert_micro_training(&interpreter, settings); + onert_micro_training.enableTrainingMode(); + onert_micro_training.train(num_train_data_samples, reinterpret_cast(train_data), + reinterpret_cast(label_train_data)); + onert_micro_training.disableTrainingMode(); + + auto size = interpreter.getInputDataSizeByIndex(0); + auto train_data_u8 = reinterpret_cast(train_data); + + printf("Run train dataset:\n"); + for (int i = 0; i < num_train_data_samples; ++i) + { + auto input_data = reinterpret_cast(interpreter.allocateInputTensor(0)); + + std::memcpy(input_data, train_data_u8, size); + + interpreter.interpret(); + auto data = reinterpret_cast(interpreter.readOutputTensor(0)); + + printf("Sample № = %d\n", i); + for (int j = 0; j < output_size / sizeof(float); ++j) + { + printf("j = %d: predicted_result = %f, correct_result = %f\n", j, data[j], + reinterpret_cast(label_train_data)[j + i * output_size / sizeof(float)]); + } + printf("\n"); + train_data_u8 += size; + } + + auto test_data_u8 = reinterpret_cast(test_data); + + printf("Run test dataset:\n"); + for (int i = 0; i < num_test_data_samples; ++i) + { + auto input_data = reinterpret_cast(interpreter.allocateInputTensor(0)); + + std::memcpy(input_data, test_data_u8, size); + + interpreter.interpret(); + auto data = reinterpret_cast(interpreter.readOutputTensor(0)); + + printf("Sample № = %d\n", i); + for (int j = 0; j < output_size / sizeof(float); ++j) + { + printf("j = %d: predicted_result = %f, correct_result = %f\n", j, data[j], + reinterpret_cast(label_test_data)[j + i * output_size / sizeof(float)]); + } + printf("\n"); + test_data_u8 += size; + } + + float mse_result = 0.0f; + + settings.metric = luci_interpreter::training::MSE; + onert_micro_training.test(num_train_data_samples, reinterpret_cast(train_data), + reinterpret_cast(label_train_data), + reinterpret_cast(&mse_result)); + + float mae_result = 0.0f; + + settings.metric = luci_interpreter::training::MAE; + onert_micro_training.test(num_train_data_samples, reinterpret_cast(train_data), + reinterpret_cast(label_train_data), + reinterpret_cast(&mae_result)); + + printf("MSE_ERROR TRAIN = %f\n", mse_result); + + printf("MAE_ERROR TRAIN = %f\n", mae_result); + + mse_result = 0.0f; + + settings.metric = luci_interpreter::training::MSE; + onert_micro_training.test(num_test_data_samples, reinterpret_cast(test_data), + reinterpret_cast(label_test_data), + reinterpret_cast(&mse_result)); + + mae_result = 0.0f; + + settings.metric = luci_interpreter::training::MAE; + onert_micro_training.test(num_test_data_samples, reinterpret_cast(test_data), + reinterpret_cast(label_test_data), + reinterpret_cast(&mae_result)); + + printf("MSE_ERROR TEST = %f\n", mse_result); + + printf("MAE_ERROR TEST = %f\n", mae_result); + + return EXIT_SUCCESS; +} + +int entry(int argc, char **argv); + +#ifdef NDEBUG +int main(int argc, char **argv) +{ + try + { + return entry(argc, argv); + } + catch (const std::exception &e) + { + std::cerr << "ERROR: " << e.what() << std::endl; + } + + return 255; +} +#else // NDEBUG +int main(int argc, char **argv) +{ + // NOTE main does not catch internal exceptions for debug build to make it easy to + // check the stacktrace with a debugger + return entry(argc, argv); +} +#endif // !NDEBUG diff --git a/onert-micro/luci-interpreter/CMakeLists.txt b/onert-micro/luci-interpreter/CMakeLists.txt index 1bdfa493f98..cfb865cb781 100644 --- a/onert-micro/luci-interpreter/CMakeLists.txt +++ b/onert-micro/luci-interpreter/CMakeLists.txt @@ -25,6 +25,10 @@ if (DIS_FLOAT) add_definitions(-DDIS_FLOAT) endif() +if (ENABLE_TRAINING) + add_definitions(-DENABLE_TRAINING) +endif() + add_compile_options(-fno-exceptions) add_compile_options(-Os)