diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc
new file mode 100644
index 00000000000..615b1650c38
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "UseDefGenerator.h"
+
+#include "ir/train/TrainableGraph.h"
+#include "ir/train/Index.h"
+#include "../verifier/Verifier.h"
+
+#include <algorithm>
+#include <cassert>
+#include <stdexcept>
+
+// TODO Reduce duplicate code
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+UseDefGenerator::UseDefGenerator(const TrainableGraph &tgraph)
+  : _tgraph{tgraph}, _node_to_idx{}, _training_usedefs{}
+{
+  const auto order = _tgraph.topolSortOperations();
+  for (const auto &index : order)
+  {
+    const auto &node = _tgraph.operation(index);
+    assert(_node_to_idx.find(&node) == _node_to_idx.end());
+    _node_to_idx[&node] = index;
+  }
+
+  // Check whether loss exists
+  assert(std::any_of(order.begin(), order.end(),
+                     [&](const auto &index) {
+                       return _tgraph.operation(index).opcode() == ir::OpCode::Loss;
+                     }) &&
+         "Loss does not exist");
+}
+
+UseDefChains UseDefGenerator::operator()()
+{
+  const auto &graph = _tgraph.graph();
+  assert(ir::verifier::EdgeChecker().verify(graph));
+
+  _training_usedefs.clear();
+  graph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
+    // Initialize as empty UseDefChain
+    const auto empty_usedef_chain = UseDefChain{operand};
+    _training_usedefs.emplace(TrainingOperandIndex{idx, true}, empty_usedef_chain);
+    _training_usedefs.emplace(TrainingOperandIndex{idx, false}, empty_usedef_chain);
+  });
+
+  initForForwardingNodes();
+
+  initForBackwardingNodes();
+
+  return _training_usedefs;
+}
+
+void UseDefGenerator::visit(const train::operation::Loss &node)
+{
+  assert(_node_to_idx.find(&node) != _node_to_idx.end());
+  const auto &op_index = _node_to_idx.at(&node);
+  const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
+
+  for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
+  {
+    // Insert use of forwarding inputs
+    const auto in_forwarding_index = TrainingOperandIndex{in_index, true};
+    insertUse(in_forwarding_index, backwarding_op_index);
+  }
+
+  // Set def of backwarding(backprop) y_pred
+  const auto &y_pred_index = node.getInputs().at(train::operation::Loss::Input::Y_PRED);
+  assert(!_tgraph.operands().at(y_pred_index).isConstant());
+  const auto y_pred_outgoing_index = TrainingOperandIndex{y_pred_index, false};
+  insertBackPropDef(y_pred_outgoing_index, backwarding_op_index);
+
+  // Set def of backwarding(backprop) y_true
+  const auto &y_true_index = node.getInputs().at(train::operation::Loss::Input::Y_TRUE);
+  assert(!_tgraph.operands().at(y_true_index).isConstant());
+  const auto y_true_outgoing_index = TrainingOperandIndex{y_true_index, false};
+  insertBackPropDef(y_true_outgoing_index, backwarding_op_index);
+
+  // Remove use of backwarding output; the Loss node's backward pass has no
+  // incoming back-propagated tensor
+  const auto &out_index = node.getOutputs().at(0);
+  const auto incoming_index = TrainingOperandIndex{out_index, false};
+  auto &usedef_chain = _training_usedefs.at(incoming_index);
+  usedef_chain.removeTrainingUse(backwarding_op_index);
+}
+
+void UseDefGenerator::insertUse(const TrainingOperandIndex &operand_index,
+                                const TrainingOperationIndex &op_index)
+{
+  assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
+  auto &usedef_chain = _training_usedefs.at(operand_index);
+  usedef_chain.insertTrainingUse(op_index);
+}
+
+void UseDefGenerator::insertDef(const TrainingOperandIndex &operand_index,
+                                const TrainingOperationIndex &op_index)
+{
+  assert(operand_index.valid());
+
+  assert(_training_usedefs.find(operand_index) != _training_usedefs.end());
+  auto &usedef_chain = _training_usedefs.at(operand_index);
+  usedef_chain.insertTrainingDef(op_index);
+}
+
+void UseDefGenerator::insertBackPropDef(const TrainingOperandIndex &operand_index,
+                                        const TrainingOperationIndex &op_index)
+{
+  // NOTE There is no need to set def of constant backwarding(backprop) inputs
+  //      because they are not back-propagated.
+  if (!_tgraph.operands().at(operand_index.index()).isConstant())
+  {
+    insertDef(operand_index, op_index);
+  }
+}
+
+void UseDefGenerator::initForForwardingNodes()
+{
+  // Initialize training def-uses of forwarding operands for forwarding nodes only
+  // (i.e. forwarding nodes that do not have any backwarding node)
+  _tgraph.operands().iterate([&](const ir::OperandIndex &idx, const ir::Operand &operand) {
+    // Append forwarding def-uses as they are
+    const bool is_forward = true;
+    const auto forwarding_operand_index = TrainingOperandIndex{idx, is_forward};
+
+    const auto def = operand.getDef();
+    if (def.valid())
+    {
+      insertDef(forwarding_operand_index, TrainingOperationIndex{def, is_forward});
+    }
+
+    assert(_training_usedefs.at(forwarding_operand_index).getTrainingUses().size() == 0);
+    const auto uses = operand.getUses();
+    for (const auto &use : uses)
+      insertUse(forwarding_operand_index, TrainingOperationIndex{use, is_forward});
+  });
+}
+
+void UseDefGenerator::initForBackwardingNodes()
+{
+  const auto backward_order = _tgraph.essentialBackwardOrder();
+  // Initialize training uses of forwarding operands and def-uses of backwarding
+  // operands for backwarding nodes
+  for (const auto &op_index : backward_order)
+  {
+    const auto &node = _tgraph.operation(op_index);
+
+    // Insert use of backwarding operands (output only)
+    {
+      if (node.getOutputs().size() > 1)
+        throw std::runtime_error(
+          "UseDefGenerator does not support multiple outputs of training operation");
+
+      const auto &output = node.getOutputs().at(0);
+      const auto backwarding_op_index = TrainingOperationIndex{op_index, false};
+      const auto incoming_index = TrainingOperandIndex{output, false};
+      insertUse(incoming_index, backwarding_op_index);
+    }
+
+    // Insert uses of forwarding operands and insert defs of backwarding operands
+    node.accept(*this);
+  }
+}
+
+} // namespace train
+} // namespace ir
+} // namespace onert
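A note on the indexing scheme used throughout the implementation above (commentary, not part of the patch): every operand and operation is tracked twice, once for the forward pass and once for the backward pass, and the boolean flag in `TrainingOperandIndex` / `TrainingOperationIndex` tells the two views apart. A minimal sketch of the consequence, assuming only that the index types provide `operator==` as their use as `unordered_map` keys implies, and using a hypothetical operand index 3:

```cpp
#include "ir/train/Index.h"

#include <cassert>

int main()
{
  using onert::ir::OperandIndex;
  using onert::ir::train::TrainingOperandIndex;

  // One graph operand (hypothetical index 3), seen from both passes
  const OperandIndex idx{3};
  const TrainingOperandIndex forwarding{idx, /*is_forward=*/true};
  const TrainingOperandIndex backwarding{idx, /*is_forward=*/false};

  // The two views are distinct keys, so _training_usedefs keeps a separate
  // UseDefChain for the forward tensor and for its back-propagated gradient.
  assert(!(forwarding == backwarding));
  return 0;
}
```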
diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h
new file mode 100644
index 00000000000..369d9a22338
--- /dev/null
+++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
+#define __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
+
+#include "ir/train/TrainableOperationVisitor.h"
+
+#include "ir/train/UseDefChains.h"
+#include "ir/train/Operations.Include.h"
+
+#include <stdexcept>
+#include <unordered_map>
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+class ITrainableOperation;
+class TrainableGraph;
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+namespace onert
+{
+namespace ir
+{
+namespace train
+{
+
+struct UseDefGeneratorBase : public TrainableOperationVisitor
+{
+  virtual ~UseDefGeneratorBase() = default;
+
+protected:
+#define OP(InternalName)                                                                 \
+  virtual void visit(const operation::InternalName &) override                          \
+  {                                                                                      \
+    throw std::runtime_error("UseDefGenerator: NYI for operation '" #InternalName "'"); \
+  }
+#include "ir/train/Operations.lst"
+#undef OP
+};
+
+class UseDefGenerator : public UseDefGeneratorBase
+{
+public:
+  UseDefGenerator(void) = delete;
+  UseDefGenerator(const TrainableGraph &tgraph);
+
+public:
+  UseDefChains operator()();
+
+public:
+  void visit(const train::operation::Loss &node) override;
+
+private:
+  void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
+  void insertDef(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
+  void insertBackPropDef(const TrainingOperandIndex &operand_index,
+                         const TrainingOperationIndex &op_index);
+  void initForForwardingNodes();
+  void initForBackwardingNodes();
+
+private:
+  const TrainableGraph &_tgraph;
+  std::unordered_map<const ITrainableOperation *, OperationIndex> _node_to_idx;
+  UseDefChains _training_usedefs;
+};
+
+} // namespace train
+} // namespace ir
+} // namespace onert
+
+#endif // __ONERT_IR_TRAIN_USEDEFGENERATOR_H__
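For context (commentary, not part of the patch): a caller constructs the generator from a `TrainableGraph` and invokes its call operator to obtain the chains. A minimal sketch, assuming a fully built `tgraph` that already contains a Loss operation; `buildTrainingUseDefs` is a hypothetical helper, not an API introduced by this patch:

```cpp
#include "UseDefGenerator.h"

#include "ir/train/TrainableGraph.h"
#include "ir/train/UseDefChains.h"

namespace
{

// Hypothetical helper: derive the training use-def chains for a graph.
onert::ir::train::UseDefChains
buildTrainingUseDefs(const onert::ir::train::TrainableGraph &tgraph)
{
  // The constructor indexes every operation and asserts that a Loss node
  // exists; operator() seeds the forwarding chains, then visits nodes in
  // essential backward order to fill in the backwarding uses and defs.
  onert::ir::train::UseDefGenerator generate{tgraph};
  return generate();
}

} // namespace
```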