diff --git a/onert-micro/luci-interpreter/src/core/RuntimeGraph.cpp b/onert-micro/luci-interpreter/src/core/RuntimeGraph.cpp index 917ebbbbfd9..bff871fcbc1 100644 --- a/onert-micro/luci-interpreter/src/core/RuntimeGraph.cpp +++ b/onert-micro/luci-interpreter/src/core/RuntimeGraph.cpp @@ -172,13 +172,33 @@ void RuntimeGraph::deallocate(size_t kernel_index) { assert(_reader->get_current_subgraph_index() == _subgraph_index); assert(_is_valid && kernel_index < _dealloc_plan.size()); + +#ifdef ENABLE_TRAINING + // const uint32_t number_of_trainable_last_layers = _number_of_last_trainable_layers; + const uint32_t last_layer = _number_of_last_trainable_layers > 0 + ? _reader->operators().size() - _number_of_last_trainable_layers + : 0; +#endif // ENABLE_TRAINING + for (const circle::Tensor *tensor : _dealloc_plan[kernel_index]) { const auto it = _tensor_to_data.find(tensor); assert(it != _tensor_to_data.end()); auto *data = _tensor_to_data.at(tensor); + +#ifdef ENABLE_TRAINING + if (_number_of_last_trainable_layers > 0 and kernel_index >= last_layer) + { + _gradient_calc_storage->saveDataToTensor(tensor, data); + } + else + { + _memory_manager->release_memory(data); + } +#else _memory_manager->release_memory(data); +#endif // ENABLE_TRAINING _tensor_to_data.erase(it); } @@ -374,6 +394,21 @@ uint8_t *RuntimeGraph::getConstDataByTensor(const circle::Tensor *raw_tensor) if (raw_tensor == nullptr) return nullptr; +#ifdef ENABLE_TRAINING + assert(_number_of_last_trainable_layers == 0 or + _storage != nullptr && "Storage should not be null here"); + + if (_storage != nullptr) + { + uint8_t *result = nullptr; + _storage->getTrainWeightDataByTensor(raw_tensor, &result); + + if (result != nullptr) + return result; + } + +#endif // ENABLE_TRAINING + auto const &buffer = wrap(_reader->buffers()[raw_tensor->buffer()]->data()); return const_cast(buffer.data()); @@ -438,6 +473,20 @@ void RuntimeGraph::execute() deallocate(i); } + +#ifdef ENABLE_TRAINING + if (_number_of_last_trainable_layers > 0) + { + const auto graph_output = _reader->outputs(); + + assert(graph_output.size() == 1); + + const auto output_tensor = _reader->tensors()[graph_output[0]]; + uint8_t *output_data = _tensor_to_data.at(output_tensor); + _gradient_calc_storage->saveDataToTensor(output_tensor, output_data); + _tensor_to_data.erase(output_tensor); + } +#endif // ENABLE_TRAINING } } // namespace luci_interpreter diff --git a/onert-micro/luci-interpreter/src/core/RuntimeGraph.h b/onert-micro/luci-interpreter/src/core/RuntimeGraph.h index baac0b1b9f8..d75bfc6ed44 100644 --- a/onert-micro/luci-interpreter/src/core/RuntimeGraph.h +++ b/onert-micro/luci-interpreter/src/core/RuntimeGraph.h @@ -18,6 +18,11 @@ #define LUCI_INTERPRETER_CORE_RUNTIMEGRAPH_H #include "luci_interpreter/core/Tensor.h" + +#ifdef ENABLE_TRAINING +#include "TrainingGraph.h" +#endif // ENABLE_TRAINING + #ifdef USE_STATIC_ALLOC #include "memory_managers/StaticMemoryManager.h" #else @@ -110,6 +115,18 @@ class RuntimeGraph return _inplace_op_indexes.find(op) != _inplace_op_indexes.end(); } +#ifdef ENABLE_TRAINING + void setLastTrainingLayersNumber(uint32_t training_number) + { + _number_of_last_trainable_layers = training_number; + } + void setGradientCalculationStorage(training::GradientCalculationStorage *gradient_calc_storage) + { + _gradient_calc_storage = gradient_calc_storage; + } + void setTrainingWeightStorage(training::TrainableWeightStorage *storage) { _storage = storage; } +#endif // ENABLE_TRAINING + #ifndef DIS_DYN_SHAPES void addDynamicShapeTensor(const circle::Tensor *tensor, luci_interpreter::RuntimeShape &&shapes); @@ -133,6 +150,12 @@ class RuntimeGraph bool _is_valid = false; +#ifdef ENABLE_TRAINING + uint32_t _number_of_last_trainable_layers = 0; // 0 means there is no training + training::GradientCalculationStorage *_gradient_calc_storage = nullptr; + training::TrainableWeightStorage *_storage = nullptr; +#endif // ENABLE_TRAINING + // Tensors that are not used anymore after given op std::vector> _alloc_plan; std::vector> _dealloc_plan;