From b480c5614773b1c7c657e2d0125d5a8cd1daf177 Mon Sep 17 00:00:00 2001 From: SeungHui Youn <61981457+zetwhite@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:44:34 +0900 Subject: [PATCH] [onert/train] Register LayerScopeTensor to registry (#14235) This PR registers LayerScopeTensor from each layer into tensor registry. ONE-DCO-1.0-Signed-off-by: seunghui youn -------------------------------------- draft : https://github.com/Samsung/ONE/pull/13486 --- runtime/onert/backend/train/BackendContext.cc | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc index 446fc68c244..d0a634b1460 100644 --- a/runtime/onert/backend/train/BackendContext.cc +++ b/runtime/onert/backend/train/BackendContext.cc @@ -262,7 +262,43 @@ FunctionMap BackendContext::generateFunctionMap() void BackendContext::planLayerScopeTensors([[maybe_unused]] const FunctionMap &fn_map) { - // TODO: Register LayerScopeTensors + const auto &ops = trainable_graph()->operations(); + + auto register_tensors = [this](const ir::OperationIndex &op_idx, + std::optional &&tensors) { + if (not tensors.has_value()) + return; + + auto ls_tensors = tensors.value(); + for (auto i = 0u; i < ls_tensors.size(); ++i) + { + LayerScopeTensorIndex tensor_idx(op_idx, i); + _tensor_builder->registerLayerScopeTensor(tensor_idx, ls_tensors[i]); + + VERBOSE(BackendContext) << "(idx:" << tensor_idx << ") registered" << std::endl; + } + return; + }; + + for (auto &pair : fn_map) + { + const auto &op_idx = pair.first; + auto &fn_seq = pair.second; + + const ir::IOperation *op = &ops.at(op_idx); + const auto trainable_op = dynamic_cast(op); + assert(trainable_op != nullptr); + + if (not trainable_op->isRequiredForBackward()) + continue; + + VERBOSE(BackendContext) << "register layerscope tensor for " << trainable_op->name() + << std::endl; + + fn_seq->iterate([&](exec::train::ITrainableFunction &fn) { + register_tensors(op_idx, (&fn)->registerLayerScopeTensors()); + }); + } const auto ctx_data = data(); TensorPlanner tensor_planner{*ctx_data->tgraph.get(), ctx_data->external_operands};