From 9a5efefae9504cd0523b9a5c004ea5f9a6f34a1c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 17 Oct 2023 10:26:06 +0100 Subject: [PATCH] Fix (export): explicit export path for cifg lstm --- src/brevitas/nn/quant_rnn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index de61ffae7..396a4f6ef 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -23,7 +23,6 @@ from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import Int32Bias from brevitas.quant import Uint8ActPerTensorFloat -from brevitas.quant_tensor import QuantTensor QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]] QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]] @@ -675,9 +674,8 @@ def forward(self, inp, hidden_state, cell_state): self.output_gate_params, quant_input) if self.cifg: # Avoid dealing with None and set it the same as the forget one - quant_weight_if = quant_weight_ii - quant_weight_hf = quant_weight_hi - quant_bias_forget = quant_bias_input + quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( + self.input_gate_params, quant_input) else: quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( self.forget_gate_params, quant_input)