From 36e422aa8e588e973f987babe9f955b00c7ff135 Mon Sep 17 00:00:00 2001 From: Dayoung Lee Date: Thu, 10 Oct 2024 06:31:51 +0900 Subject: [PATCH] [luci/pass] Add origins to const node (#14181) Let's add origins to const node. Signed-off-by: Dayoung Lee --- compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp | 2 ++ compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp | 2 ++ compiler/luci/pass/src/SubstitutePackToReshapePass.cpp | 1 + compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp | 1 + compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp | 1 + 5 files changed, 7 insertions(+) diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp index add55f66c32..b180ce5392d 100644 --- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp +++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp @@ -85,6 +85,7 @@ bool resolve_matmul(luci::CircleCustom *cop) std::swap(perm[circle_lhs->rank() - 1], perm[circle_lhs->rank() - 2]); auto perm_node = luci::create_const_node(graph, S32, {circle_lhs->rank()}, perm); perm_node->name(name + "/lhs/Transpose/perm"); + luci::add_origin(perm_node, luci::get_origin(cop)); // Now make a transpose node auto transpose_node = graph->nodes()->create(); transpose_node->a(lhs); @@ -102,6 +103,7 @@ bool resolve_matmul(luci::CircleCustom *cop) const std::vector perm{1, 0}; auto perm_node = luci::create_const_node(graph, S32, {2}, perm); perm_node->name(name + "/rhs/Transpose/perm"); + luci::add_origin(perm_node, luci::get_origin(cop)); auto transpose_node = graph->nodes()->create(); transpose_node->a(rhs); transpose_node->perm(perm_node); diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp index 92060f6251a..bffaaad0722 100644 --- a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp +++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp @@ -17,6 +17,7 @@ #include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" #include +#include #include #include @@ -86,6 +87,7 @@ luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc) new_weights->dim(r).set(the_weights->dim(r).value()); } new_weights->name(name + "/shuffle_weight"); + luci::add_origin(new_weights, luci::get_origin(the_weights)); // suffle weight const uint32_t MULTIPLE = 16; diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp index d8676cd62b4..188ff496d45 100644 --- a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp @@ -62,6 +62,7 @@ bool substitute_pack_to_reshape(luci::CircleNode *node) const_node->shape_status(luci::ShapeStatus::VALID); const_node->rank(1); const_node->dim(0).set(value_node->rank() + 1); + luci::add_origin(const_node, luci::get_origin(node)); for (int32_t i = 0; i < static_cast(value_node->rank()) + 1; i++) { if (i == axis) diff --git a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp index 9bc764f92ec..43e420d8e4f 100644 --- a/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstituteSqueezeToReshapePass.cpp @@ -134,6 +134,7 @@ bool substitute_squeeze_to_reshape(luci::CircleSqueeze *squeeze) reshape->name(name + "/Reshape"); luci::add_origin(reshape, luci::get_origin(squeeze)); shape_const->name(name + "/Reshape/shape"); + luci::add_origin(shape_const, luci::get_origin(squeeze)); // graph connection reshape->tensor(input); diff --git a/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp index dfd5e6cf2b3..a5381e0b359 100644 --- a/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp +++ b/compiler/luci/pass/src/SubstituteTransposeToReshapePass.cpp @@ -90,6 +90,7 @@ bool substitute_transpose_to_reshape(luci::CircleTranspose *node) new_reshape_node->name(name + "/Reshape"); luci::add_origin(new_reshape_node, luci::get_origin(node)); new_const_node->name(name + "/Reshape/shape"); + luci::add_origin(new_const_node, luci::get_origin(node)); replace(node).with(new_reshape_node); return true;