diff --git a/src/brevitas/nn/quant_rnn.py b/src/brevitas/nn/quant_rnn.py index de61ffae7..396a4f6ef 100644 --- a/src/brevitas/nn/quant_rnn.py +++ b/src/brevitas/nn/quant_rnn.py @@ -23,7 +23,6 @@ from brevitas.quant import Int8WeightPerTensorFloat from brevitas.quant import Int32Bias from brevitas.quant import Uint8ActPerTensorFloat -from brevitas.quant_tensor import QuantTensor QuantTupleShortEnabled = List[Tuple[Tensor, Tensor, Tensor, Tensor]] QuantTupleShortDisabled = List[Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]] @@ -675,9 +674,8 @@ def forward(self, inp, hidden_state, cell_state): self.output_gate_params, quant_input) if self.cifg: # Avoid dealing with None and set it the same as the forget one - quant_weight_if = quant_weight_ii - quant_weight_hf = quant_weight_hi - quant_bias_forget = quant_bias_input + quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( + self.input_gate_params, quant_input) else: quant_weight_if, quant_weight_hf, quant_bias_forget = self.gate_params_fwd( self.forget_gate_params, quant_input)