Skip to content

Commit

Permalink
[onert-micro] Add training into RuntimeModule
Browse files Browse the repository at this point in the history
This commit adds training into RuntimeModule in onert-micro.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Artem Balyshev committed Sep 18, 2023
1 parent 4148adb commit b2044b5
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions onert-micro/luci-interpreter/src/core/RuntimeModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
#include "core/RuntimeGraph.h"
#include "luci_interpreter/core/reader/CircleMicroReader.h"

#ifdef ENABLE_TRAINING
#include "luci_interpreter/TrainableWeightStorage.h"
#include "TrainingGraph.h"
#include "core/RuntimeGraph.h"
#endif // ENABLE_TRAINING


#include <memory>
#include <vector>

Expand All @@ -34,11 +41,24 @@ using BaseRuntimeGraph = RuntimeGraph;
using MemoryManager = SimpleMemoryManager;
#endif // USE_STATIC_ALLOC

#ifdef ENABLE_TRAINING
namespace training
{

class TrainingModule;

} // namespace training
#endif // ENABLE_TRAINING

class RuntimeModule
{
public:
RuntimeModule() = default;

#ifdef ENABLE_TRAINING
friend class training::TrainingModule;
#endif // ENABLE_TRAINING

void addGraph(MemoryManager *memory_manager)
{
_graphs.emplace_back(memory_manager, &_circle_reader, this,
Expand All @@ -59,6 +79,10 @@ class RuntimeModule
std::vector<BaseRuntimeGraph> _graphs;

CircleReader _circle_reader;

#ifdef ENABLE_TRAINING
std::unique_ptr<training::TrainableWeightStorage> _storage;
#endif // ENABLE_TRAINING
};

} // namespace luci_interpreter
Expand Down

0 comments on commit b2044b5

Please sign in to comment.