From 23cb3332ee539620ca2c35fc688337d8edbcd903 Mon Sep 17 00:00:00 2001 From: Artem Balyshev Date: Mon, 18 Sep 2023 18:43:27 +0300 Subject: [PATCH] [onert-micro] Add training into RuntimeModule This commit adds training into RuntimeModule in onert-micro. ONE-DCO-1.0-Signed-off-by: Artem Balyshev --- .../luci-interpreter/src/core/RuntimeModule.h | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/onert-micro/luci-interpreter/src/core/RuntimeModule.h b/onert-micro/luci-interpreter/src/core/RuntimeModule.h index d42698277eb..ceb9ef160fa 100644 --- a/onert-micro/luci-interpreter/src/core/RuntimeModule.h +++ b/onert-micro/luci-interpreter/src/core/RuntimeModule.h @@ -20,6 +20,12 @@ #include "core/RuntimeGraph.h" #include "luci_interpreter/core/reader/CircleMicroReader.h" +#ifdef ENABLE_TRAINING +#include "luci_interpreter/TrainableWeightStorage.h" +#include "TrainingGraph.h" +#include "core/RuntimeGraph.h" +#endif // ENABLE_TRAINING + #include #include @@ -34,11 +40,24 @@ using BaseRuntimeGraph = RuntimeGraph; using MemoryManager = SimpleMemoryManager; #endif // USE_STATIC_ALLOC +#ifdef ENABLE_TRAINING +namespace training +{ + +class TrainingModule; + +} // namespace training +#endif // ENABLE_TRAINING + class RuntimeModule { public: RuntimeModule() = default; +#ifdef ENABLE_TRAINING + friend class training::TrainingModule; +#endif // ENABLE_TRAINING + void addGraph(MemoryManager *memory_manager) { _graphs.emplace_back(memory_manager, &_circle_reader, this, @@ -59,6 +78,10 @@ class RuntimeModule std::vector _graphs; CircleReader _circle_reader; + +#ifdef ENABLE_TRAINING + std::unique_ptr _storage; +#endif // ENABLE_TRAINING }; } // namespace luci_interpreter