[onert-micro] Add training into RuntimeGraph (#11548)
This commit adds training support to RuntimeGraph in onert-micro.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>

Co-authored-by: Artem Balyshev <[email protected]>
BalyshevArtem and Artem Balyshev authored Sep 19, 2023
1 parent 1c0d919 commit 2d63222
Showing 2 changed files with 72 additions and 0 deletions.
49 changes: 49 additions & 0 deletions onert-micro/luci-interpreter/src/core/RuntimeGraph.cpp
@@ -172,13 +172,33 @@ void RuntimeGraph::deallocate(size_t kernel_index)
{
assert(_reader->get_current_subgraph_index() == _subgraph_index);
assert(_is_valid && kernel_index < _dealloc_plan.size());

#ifdef ENABLE_TRAINING
// const uint32_t number_of_trainable_last_layers = _number_of_last_trainable_layers;
const uint32_t last_layer = _number_of_last_trainable_layers > 0
? _reader->operators().size() - _number_of_last_trainable_layers
: 0;
#endif // ENABLE_TRAINING

for (const circle::Tensor *tensor : _dealloc_plan[kernel_index])
{
const auto it = _tensor_to_data.find(tensor);
assert(it != _tensor_to_data.end());

auto *data = _tensor_to_data.at(tensor);

#ifdef ENABLE_TRAINING
if (_number_of_last_trainable_layers > 0 and kernel_index >= last_layer)
{
_gradient_calc_storage->saveDataToTensor(tensor, data);
}
else
{
_memory_manager->release_memory(data);
}
#else
_memory_manager->release_memory(data);
#endif // ENABLE_TRAINING

_tensor_to_data.erase(it);
}
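
For context, a minimal sketch of the index arithmetic above; the helper name and the example numbers are hypothetical and not part of this commit. With N operators in the graph and the last K of them marked trainable, kernels with index >= N - K hand their freed tensor data to the gradient storage instead of returning it to the memory manager.

#include <cassert>
#include <cstdint>

// Sketch only: mirrors the `kernel_index >= last_layer` check in
// RuntimeGraph::deallocate; not part of the commit.
bool isInTrainableTail(uint32_t kernel_index, uint32_t num_operators,
                       uint32_t number_of_last_trainable_layers)
{
  if (number_of_last_trainable_layers == 0)
    return false; // training disabled: tensors are released as usual

  assert(number_of_last_trainable_layers <= num_operators);
  const uint32_t last_layer = num_operators - number_of_last_trainable_layers;
  return kernel_index >= last_layer;
}

// Example: 10 operators, train the last 3 -> kernels 7, 8 and 9 keep their
// data in the gradient storage; kernels 0..6 release it as before.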
@@ -374,6 +394,21 @@ uint8_t *RuntimeGraph::getConstDataByTensor(const circle::Tensor *raw_tensor)
if (raw_tensor == nullptr)
return nullptr;

#ifdef ENABLE_TRAINING
assert((_number_of_last_trainable_layers == 0 or _storage != nullptr) &&
       "Storage should not be null here");

if (_storage != nullptr)
{
uint8_t *result = nullptr;
_storage->getTrainWeightDataByTensor(raw_tensor, &result);

if (result != nullptr)
return result;
}

#endif // ENABLE_TRAINING

auto const &buffer = wrap(_reader->buffers()[raw_tensor->buffer()]->data());

return const_cast<uint8_t *>(buffer.data());
@@ -438,6 +473,20 @@ void RuntimeGraph::execute()

deallocate(i);
}

#ifdef ENABLE_TRAINING
if (_number_of_last_trainable_layers > 0)
{
const auto graph_output = _reader->outputs();

assert(graph_output.size() == 1);

const auto output_tensor = _reader->tensors()[graph_output[0]];
uint8_t *output_data = _tensor_to_data.at(output_tensor);
_gradient_calc_storage->saveDataToTensor(output_tensor, output_data);
_tensor_to_data.erase(output_tensor);
}
#endif // ENABLE_TRAINING
}

} // namespace luci_interpreter
23 changes: 23 additions & 0 deletions onert-micro/luci-interpreter/src/core/RuntimeGraph.h
@@ -18,6 +18,11 @@
#define LUCI_INTERPRETER_CORE_RUNTIMEGRAPH_H

#include "luci_interpreter/core/Tensor.h"

#ifdef ENABLE_TRAINING
#include "TrainingGraph.h"
#endif // ENABLE_TRAINING

#ifdef USE_STATIC_ALLOC
#include "memory_managers/StaticMemoryManager.h"
#else
@@ -110,6 +115,18 @@ class RuntimeGraph
return _inplace_op_indexes.find(op) != _inplace_op_indexes.end();
}

#ifdef ENABLE_TRAINING
void setLastTrainingLayersNumber(uint32_t training_number)
{
_number_of_last_trainable_layers = training_number;
}
void setGradientCalculationStorage(training::GradientCalculationStorage *gradient_calc_storage)
{
_gradient_calc_storage = gradient_calc_storage;
}
void setTrainingWeightStorage(training::TrainableWeightStorage *storage) { _storage = storage; }
#endif // ENABLE_TRAINING

#ifndef DIS_DYN_SHAPES
void addDynamicShapeTensor(const circle::Tensor *tensor, luci_interpreter::RuntimeShape &&shapes);

@@ -133,6 +150,12 @@

bool _is_valid = false;

#ifdef ENABLE_TRAINING
uint32_t _number_of_last_trainable_layers = 0; // 0 means there is no training
training::GradientCalculationStorage *_gradient_calc_storage = nullptr;
training::TrainableWeightStorage *_storage = nullptr;
#endif // ENABLE_TRAINING

// Tensors that are not used anymore after given op
std::vector<std::vector<const circle::Tensor *>> _alloc_plan;
std::vector<std::vector<const circle::Tensor *>> _dealloc_plan;
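
As a hedged usage sketch of the new header API: a training driver might wire the hooks like this before calling execute(). The helper function, the include path, the namespace qualification of the training storages, and how those storages are created are assumptions for illustration, not shown in this commit.

// Sketch only: the include path is assumed from the source tree layout;
// RuntimeGraph.h pulls in TrainingGraph.h when ENABLE_TRAINING is defined.
#include <cstdint>
#include "core/RuntimeGraph.h"

#ifdef ENABLE_TRAINING
// Hypothetical helper, not part of the commit: configure a graph so that its
// last `last_trainable_layers` operators participate in training.
void configureTraining(luci_interpreter::RuntimeGraph &graph,
                       luci_interpreter::training::GradientCalculationStorage &grad_storage,
                       luci_interpreter::training::TrainableWeightStorage &weight_storage,
                       uint32_t last_trainable_layers)
{
  // 0 leaves training disabled; > 0 selects the trainable tail of the graph.
  graph.setLastTrainingLayersNumber(last_trainable_layers);
  // Tensor data freed in the trainable tail (and the graph output) is kept here.
  graph.setGradientCalculationStorage(&grad_storage);
  // Weights found in this storage shadow the model's constant buffers.
  graph.setTrainingWeightStorage(&weight_storage);
}
#endif // ENABLE_TRAINING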
