From ceb6706e1a9e2da48597d23ad9513d010501c370 Mon Sep 17 00:00:00 2001 From: SlavikMIPT Date: Thu, 14 Sep 2023 16:35:58 +0300 Subject: [PATCH] [onert-micro] Introduce Trainable weight storage (#11398) This commit introduces Trainable weight storage class. ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../luci_interpreter/TrainableWeightStorage.h | 61 +++++++++ .../luci-interpreter/src/core/CMakeLists.txt | 3 +- .../src/core/TrainableWeightStorage.cpp | 117 ++++++++++++++++++ 3 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 onert-micro/luci-interpreter/include/luci_interpreter/TrainableWeightStorage.h create mode 100644 onert-micro/luci-interpreter/src/core/TrainableWeightStorage.cpp diff --git a/onert-micro/luci-interpreter/include/luci_interpreter/TrainableWeightStorage.h b/onert-micro/luci-interpreter/include/luci_interpreter/TrainableWeightStorage.h new file mode 100644 index 00000000000..9b4bd79196f --- /dev/null +++ b/onert-micro/luci-interpreter/include/luci_interpreter/TrainableWeightStorage.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_TRAINING + +#ifndef LUCI_INTERPRETER_CORE_TRAINABLE_WEIGHT_STORAGE_H +#define LUCI_INTERPRETER_CORE_TRAINABLE_WEIGHT_STORAGE_H + +#include "luci_interpreter/TrainingSettings.h" +#include "luci_interpreter/core/reader/CircleMicroReader.h" +#include "memory_managers/SimpleMemoryManager.h" + +#include + +namespace luci_interpreter +{ +namespace training +{ + +class TrainableWeightStorage +{ +public: + TrainableWeightStorage() = default; + +public: + Status getTrainWeightDataByTensor(const circle::Tensor *tensor, uint8_t **result_data); + + Status clearAllTrainableWeights(); + + training::Status fillTrainableWeightsStorage(const CircleReader *reader, + SimpleMemoryManager *memory_manager, + uint32_t number_of_last_trainable_layers); + +private: + Status createTrainableWeightForTensor(const circle::Tensor *tensor, + SimpleMemoryManager *memoryManager, + const uint8_t *const_data); + +private: + std::unordered_map _tensor_to_data; +}; + +} // namespace training +} // namespace luci_interpreter + +#endif // LUCI_INTERPRETER_CORE_TRAINABLE_WEIGHT_STORAGE_H + +#endif // ENABLE_TRAINING diff --git a/onert-micro/luci-interpreter/src/core/CMakeLists.txt b/onert-micro/luci-interpreter/src/core/CMakeLists.txt index 48d92aa372f..b0e101d54c4 100644 --- a/onert-micro/luci-interpreter/src/core/CMakeLists.txt +++ b/onert-micro/luci-interpreter/src/core/CMakeLists.txt @@ -3,7 +3,8 @@ set(SOURCES "${LUCI_INTERPRETER_INCLUDE_DIR}/luci_interpreter/core/Tensor.h" RuntimeGraph.h RuntimeGraph.cpp - RuntimeModule.h) + RuntimeModule.h + TrainableWeightStorage.cpp) add_library(${LUCI_INTERPRETER_CORE} STATIC ${SOURCES}) if (NOT NNCC_LIBRARY_NO_PIC) diff --git a/onert-micro/luci-interpreter/src/core/TrainableWeightStorage.cpp b/onert-micro/luci-interpreter/src/core/TrainableWeightStorage.cpp new file mode 100644 index 00000000000..dd21298ec66 --- /dev/null +++ b/onert-micro/luci-interpreter/src/core/TrainableWeightStorage.cpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_TRAINING + +#include "luci_interpreter/core/TrainableWeightStorage.h" + +namespace luci_interpreter +{ +namespace training +{ + +Status TrainableWeightStorage::createTrainableWeightForTensor(const circle::Tensor *tensor, + SimpleMemoryManager *memoryManager, + const uint8_t *const_data) +{ + assert(_tensor_to_data.count(tensor) == 0 && "Double training weight"); + + if (_tensor_to_data.count(tensor) != 0) + { + return Error; + } + + uint8_t *allocated_data = memoryManager->allocate_memory(tensor); + + std::memcpy(allocated_data, const_data, + size(Tensor::element_type(tensor)) * Tensor::num_elements(tensor)); + + _tensor_to_data[tensor] = allocated_data; + + return Ok; +} + +training::Status +TrainableWeightStorage::fillTrainableWeightsStorage(const CircleReader *reader, + SimpleMemoryManager *memory_manager, + uint32_t number_of_last_trainable_layers) +{ + const auto operators_size = reader->operators().size(); + const auto operators = reader->operators(); + + const uint32_t first_trainable_layer_pos = operators_size - number_of_last_trainable_layers; + + for (uint32_t i = first_trainable_layer_pos; i < operators_size; ++i) + { + const auto op = operators.at(i); + assert(op != nullptr); + + const auto *op_inputs = op->inputs(); + + for (const int32_t input_idx : *op_inputs) + { + if (input_idx == -1) + continue; + const circle::Tensor *tensor = reader->tensors()[input_idx]; + + if (_tensor_to_data.count(tensor) > 0) + continue; + + const auto tensor_data = reader->buffers()[tensor->buffer()]->data(); + if (tensor_data != nullptr) + { + if (createTrainableWeightForTensor(tensor, memory_manager, tensor_data->data()) == + training::Error) + return training::Error; + } + } + } + return training::Ok; +} + +Status TrainableWeightStorage::clearAllTrainableWeights() +{ + for (const auto &pair : _tensor_to_data) + { + delete[] pair.second; + } + + _tensor_to_data.clear(); + return Ok; +} + +Status TrainableWeightStorage::getTrainWeightDataByTensor(const circle::Tensor *tensor, + uint8_t **result_data) +{ + assert(tensor != nullptr); // CALLER SIDE + + auto it = _tensor_to_data.find(tensor); + + if (it == _tensor_to_data.end()) + { + result_data = nullptr; + return Ok; + } + + *result_data = it->second; + + return Ok; +} + +} // namespace training +} // namespace luci_interpreter + +#endif // ENABLE_TRAINING