From afdbc7c3a394c92cf10f475f3178808c61e11e14 Mon Sep 17 00:00:00 2001 From: Hyukjin Jeong Date: Fri, 20 Dec 2024 16:29:22 +0900 Subject: [PATCH] [luci] Propagate qparam backward in onnx-fake quant model This propagates qparam backward in onnx-fake quant model. ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong --- compiler/luci/pass/src/PropagateQParamBackwardPass.cpp | 10 ++-------- .../luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp | 7 +++++++ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp index 18617e3b77b..da115b06dcf 100644 --- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp @@ -107,6 +107,8 @@ void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *targ target_qparam->scale = source_qparam->scale; target_qparam->zerop = source_qparam->zerop; target_qparam->quantized_dimension = source_qparam->quantized_dimension; + + target->dtype(source->dtype()); } /** @@ -188,8 +190,6 @@ void propagate_pack_quantparam(luci::CirclePack *pack) if (succs.size() > 1) continue; - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); overwrite_quantparam(pack, node); } } @@ -260,8 +260,6 @@ void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot) if (succs.size() > 1) return; - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); overwrite_quantparam(one_hot, node); } }; @@ -340,8 +338,6 @@ void propagate_concat_quantparam(luci::CircleConcatenation *concat) if (succs.size() > 1) continue; - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); overwrite_quantparam(concat, node); } } @@ -440,8 +436,6 @@ void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2) if (succs.size() > 1) return; - // Non-const input must have been quantized - assert(node->quantparam() != nullptr); overwrite_quantparam(pad_v2, node); } }; diff --git a/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp index face706b2cc..b95c7ea150c 100644 --- a/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp +++ b/compiler/luci/pass/src/QuantizeOnnxFakeQuantModelPass.cpp @@ -14,6 +14,7 @@ */ #include "luci/Pass/QuantizeOnnxFakeQuantModelPass.h" +#include "luci/Pass/PropagateQParamBackwardPass.h" #include "QuantizeOnnxQDQPass.h" #include "QuantizeOnnxDequantizeLinearPass.h" #include "QuantizeWithPredecessorPass.h" @@ -92,6 +93,12 @@ bool QuantizeOnnxFakeQuantModelPass::run(loco::Graph *g) pass.run(g); } + // Backward propagation of activation qparam + { + PropagateQParamBackwardPass pqbp(_ctx->default_activation_dtype); + pqbp.run(g); + } + // Update qparam of output of special Ops for (auto node : loco::active_nodes(loco::output_nodes(g))) {