Skip to content

Commit

Permalink
[onert-micro] Introduce TrainingModule class
Browse files Browse the repository at this point in the history
This commit introduces TrainingModule class.

ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov <[email protected]>
ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
Vyacheslav Bazhenov committed Aug 31, 2023
1 parent 2cb3063 commit 32f5a86
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 1 deletion.
3 changes: 2 additions & 1 deletion onert-micro/luci-interpreter/src/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ set(SOURCES
"${LUCI_INTERPRETER_INCLUDE_DIR}/luci_interpreter/core/Tensor.h"
RuntimeGraph.h
RuntimeGraph.cpp
RuntimeModule.h)
RuntimeModule.h
TrainingModule.cpp)

add_library(${LUCI_INTERPRETER_CORE} STATIC ${SOURCES})
if (NOT NNCC_LIBRARY_NO_PIC)
Expand Down
78 changes: 78 additions & 0 deletions onert-micro/luci-interpreter/src/core/TrainingModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef ENABLE_TRAINING

#include "TrainingModule.h"

#include <memory>

namespace luci_interpreter
{
namespace training
{

training::Status TrainingModule::enableTrainingMode(training::TrainingSettings &settings,
SimpleMemoryManager *memoryManager)
{
if (_runtime_module->_storage.get() == nullptr)
{
_runtime_module->_storage = std::make_unique<TrainableWeightStorage>();
}

if (_runtime_module->_storage->fillTrainableWeightsStorage(
&_runtime_module->_circle_reader, memoryManager,
settings.number_of_last_trainable_layers) == training::Error)
return training::Error;

_training_graph = std::make_unique<training::TrainingGraph>();

for (auto &graph : _runtime_module->_graphs)
{
graph.setLastTrainingLayersNumber(settings.number_of_last_trainable_layers);
graph.setGradientCalculationStorage(_training_graph->getGradientCalculationStorage());
graph.setTrainingWeightStorage(_runtime_module->_storage.get());
}

return training::Ok;
}

training::Status TrainingModule::disableTrainingMode(bool resetWeights)
{
_training_graph.release();

if (resetWeights)
{
if (_runtime_module->_storage->clearAllTrainableWeights() == training::Error)
return training::Error;
_runtime_module->_storage.release();
}

for (auto &graph : _runtime_module->_graphs)
{
graph.setLastTrainingLayersNumber(0);
graph.setGradientCalculationStorage(nullptr);
if (resetWeights)
graph.setTrainingWeightStorage(nullptr);
}

return training::Ok;
}

} // namespace training
} // namespace luci_interpreter

#endif // ENABLE_TRAINING
69 changes: 69 additions & 0 deletions onert-micro/luci-interpreter/src/core/TrainingModule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef ENABLE_TRAINING

#ifndef LUCI_INTERPRETER_CORE_TRAINING_MODULE_H
#define LUCI_INTERPRETER_CORE_TRAINING_MODULE_H

#include "core/RuntimeModule.h"

#include "luci_interpreter/core/TrainableWeightStorage.h"
#include "TrainingGraph.h"

namespace luci_interpreter
{
namespace training
{

class TrainingModule
{
public:
TrainingModule(RuntimeModule *runtime_module) : _runtime_module(runtime_module)
{
// Do nothing
}

training::Status enableTrainingMode(training::TrainingSettings &settings,
SimpleMemoryManager *memoryManager);

training::Status disableTrainingMode(bool resetWeights);

training::Status computeGradients(const TrainingSettings &settings,
const uint8_t *label_train_data)
{
return _training_graph->computeGradients(settings, _runtime_module->_storage.get(),
&_runtime_module->_circle_reader, label_train_data);
}

training::Status updateWeights(const TrainingSettings &settings)
{
return _training_graph->updateWeights(settings, _runtime_module->_storage.get(),
&_runtime_module->_circle_reader);
}

private:
RuntimeModule *_runtime_module;

std::unique_ptr<training::TrainingGraph> _training_graph;
};

} // namespace training
} // namespace luci_interpreter

#endif // LUCI_INTERPRETER_CORE_TRAINING_MODULE_H

#endif // ENABLE_TRAINING

0 comments on commit 32f5a86

Please sign in to comment.