diff --git a/onert-micro/luci-interpreter/src/core/RuntimeModule.h b/onert-micro/luci-interpreter/src/core/RuntimeModule.h index d42698277eb..5f3692129f8 100644 --- a/onert-micro/luci-interpreter/src/core/RuntimeModule.h +++ b/onert-micro/luci-interpreter/src/core/RuntimeModule.h @@ -20,6 +20,13 @@ #include "core/RuntimeGraph.h" #include "luci_interpreter/core/reader/CircleMicroReader.h" +#ifdef ENABLE_TRAINING +#include "luci_interpreter/TrainableWeightStorage.h" +#include "TrainingGraph.h" +#include "core/RuntimeGraph.h" +#endif // ENABLE_TRAINING + + #include #include @@ -34,11 +41,24 @@ using BaseRuntimeGraph = RuntimeGraph; using MemoryManager = SimpleMemoryManager; #endif // USE_STATIC_ALLOC +#ifdef ENABLE_TRAINING +namespace training +{ + +class TrainingModule; + +} // namespace training +#endif // ENABLE_TRAINING + class RuntimeModule { public: RuntimeModule() = default; +#ifdef ENABLE_TRAINING + friend class training::TrainingModule; +#endif // ENABLE_TRAINING + void addGraph(MemoryManager *memory_manager) { _graphs.emplace_back(memory_manager, &_circle_reader, this, @@ -59,6 +79,10 @@ class RuntimeModule std::vector _graphs; CircleReader _circle_reader; + +#ifdef ENABLE_TRAINING + std::unique_ptr _storage; +#endif // ENABLE_TRAINING }; } // namespace luci_interpreter