Skip to content

Commit

Permalink
[onert] Move OptimizerCode and OptimizerInfo into ir (Samsung#11511)
Browse files Browse the repository at this point in the history
This commit moves OptimizerCode and OptimizerInfo into ir.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Sep 18, 2023
1 parent 2537288 commit 5e937f2
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 30 deletions.
6 changes: 3 additions & 3 deletions runtime/onert/api/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1197,13 +1197,13 @@ NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info)

auto convertOptType = [](const int &type) {
if (type == NNFW_TRAIN_OPTIMIZER_SGD)
return onert::exec::train::optimizer::OptimizerCode::SGD;
return onert::ir::train::OptimizerCode::SGD;
else if (type == NNFW_TRAIN_OPTIMIZER_ADAM)
return onert::exec::train::optimizer::OptimizerCode::Adam;
return onert::ir::train::OptimizerCode::Adam;
else
throw std::runtime_error("not supported optimizer type");
};
onert::compiler::train::OptimizerInfo opt_info;
onert::ir::train::OptimizerInfo opt_info;
opt_info.learning_rate = tinfo.learning_rate;
opt_info.optim_code = convertOptType(tinfo.opt);

Expand Down
19 changes: 8 additions & 11 deletions runtime/onert/core/include/compiler/train/TrainingInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
#define __ONERT_COMPILER_TRAIN_TRAINING_INFO_H__

#include "ir/Index.h"
#include "exec/train/optimizer/OptimizerCode.h"
#include "ir/operation/Loss.h"
#include "ir/train/OptimizerCode.h"
#include "ir/train/OptimizerInfo.h"

namespace onert
{
Expand All @@ -34,13 +35,6 @@ struct LossInfo
// TODO Add members for loss
};

struct OptimizerInfo
{
exec::train::optimizer::OptimizerCode optim_code;
float learning_rate;
// TODO Add properties
};

class TrainingInfo
{
public:
Expand All @@ -55,12 +49,15 @@ class TrainingInfo
void setBatchSize(const uint32_t batch_size) { _batch_size = batch_size; }
const LossInfo &lossInfo() const { return _loss_info; }
void setLossInfo(const LossInfo &loss_info) { _loss_info = loss_info; }
const OptimizerInfo &optimizerInfo() const { return _optimizer_info; }
void setOptimizerInfo(const OptimizerInfo &optimizer_info) { _optimizer_info = optimizer_info; }
const ir::train::OptimizerInfo &optimizerInfo() const { return _optimizer_info; }
void setOptimizerInfo(const ir::train::OptimizerInfo &optimizer_info)
{
_optimizer_info = optimizer_info;
}

private:
LossInfo _loss_info;
OptimizerInfo _optimizer_info;
ir::train::OptimizerInfo _optimizer_info;
uint32_t _batch_size;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
* limitations under the License.
*/

#ifndef __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_CODE_H__
#define __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_CODE_H__
#ifndef __ONERT_IR_TRAIN_OPTIMIZER_CODE_H__
#define __ONERT_IR_TRAIN_OPTIMIZER_CODE_H__

#include <functional>
#include <stdint.h>
#include <string>

namespace onert
{
namespace exec
namespace ir
{
namespace train
{
namespace optimizer
{

enum class OptimizerCode
{
Expand All @@ -45,9 +43,8 @@ enum class OptimizerCode
*/
std::string toString(OptimizerCode opcode);

} // namespace optimizer
} // namespace train
} // namespace exec
} // namespace ir
} // namespace onert

#endif // __ONERT_EXEC_TRAIN_OPTIMIZER_OPTIMIZER_CODE_H__
#endif // __ONERT_IR_TRAIN_OPTIMIZER_CODE_H__
40 changes: 40 additions & 0 deletions runtime/onert/core/include/ir/train/OptimizerInfo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_IR_TRAIN_OPTIMIZER_INFO_H__
#define __ONERT_IR_TRAIN_OPTIMIZER_INFO_H__

#include "OptimizerCode.h"

namespace onert
{
namespace ir
{
namespace train
{

struct OptimizerInfo
{
OptimizerCode optim_code;
float learning_rate;
// TODO Add properties
};

} // namespace train
} // namespace ir
} // namespace onert

#endif // __ONERT_IR_TRAIN_OPTIMIZER_INFO_H__
4 changes: 2 additions & 2 deletions runtime/onert/core/src/compiler/train/TrainingCompiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ std::shared_ptr<CompilerArtifact> TrainingCompiler::compile(void)
// TODO Set properties of optimizer
std::shared_ptr<exec::train::optimizer::Optimizer> optimizer;
const auto &optim_info = _training_info.optimizerInfo();
if (optim_info.optim_code == exec::train::optimizer::OptimizerCode::SGD)
if (optim_info.optim_code == ir::train::OptimizerCode::SGD)
optimizer = std::make_shared<exec::train::optimizer::SGD>(optim_info.learning_rate);
else
throw std::runtime_error("Invalid optimizer type, " +
exec::train::optimizer::toString(optim_info.optim_code));
ir::train::toString(optim_info.optim_code));

/*************************************************************
* Backend independent analysis & optimization phase finished
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@
* limitations under the License.
*/

#include "exec/train/optimizer/OptimizerCode.h"
#include "ir/train/OptimizerCode.h"

#include <unordered_map>

namespace onert
{
namespace exec
namespace ir
{
namespace train
{
namespace optimizer
{

std::string toString(OptimizerCode code)
{
Expand All @@ -36,7 +34,6 @@ std::string toString(OptimizerCode code)
return map.at(code);
}

} // namespace optimizer
} // namespace train
} // namespace exec
} // namespace ir
} // namespace onert

0 comments on commit 5e937f2

Please sign in to comment.