Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 9, 2024
1 parent 63b9945 commit b0d7b62
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
5 changes: 1 addition & 4 deletions src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,16 +243,13 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
@staticmethod
def gate_params_fwd(gate, quant_input):
acc_scale = None
acc_bit_width = None
quant_weight_ih = gate.input_weight()
quant_weight_hh = gate.hidden_weight()
if quant_input.bit_width is not None:
acc_bit_width = None # TODO
if quant_input.scale is not None and quant_weight_ih.scale is not None:
acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
quant_bias = gate.bias_quant(gate.bias, acc_scale, acc_bit_width)
quant_bias = gate.bias_quant(gate.bias, acc_scale)
return quant_weight_ih, quant_weight_hh, quant_bias

def reset_parameters(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/proxy/parameter_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,17 @@ def scale(self):
if self.requires_input_scale:
return None
zhs = self._zero_hw_sentinel()
scale = self.__call__(self.tracked_parameter_list[0], zhs, zhs).scale
scale = self.__call__(self.tracked_parameter_list[0], zhs).scale
return scale

def zero_point(self):
zhs = self._zero_hw_sentinel()
zero_point = self.__call__(self.tracked_parameter_list[0], zhs, zhs).zero_point
zero_point = self.__call__(self.tracked_parameter_list[0], zhs).zero_point
return zero_point

def bit_width(self):
zhs = self._zero_hw_sentinel()
bit_width = self.__call__(self.tracked_parameter_list[0], zhs, zhs).bit_width
bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
return bit_width

def forward(self, x: Tensor, input_scale: Optional[Tensor] = None) -> QuantTensor:
Expand Down

0 comments on commit b0d7b62

Please sign in to comment.