diff --git a/src/HGQ/proxy/fixed_point_quantizer.py b/src/HGQ/proxy/fixed_point_quantizer.py index 97ad5d7..6588204 100644 --- a/src/HGQ/proxy/fixed_point_quantizer.py +++ b/src/HGQ/proxy/fixed_point_quantizer.py @@ -209,7 +209,9 @@ def heterogeneous(self): def get_config(self): assert tf.reduce_all((self.keep_negative == 0) | (self.keep_negative == 1)), 'Illegal bitwidth config: keep_negative must be 0 or 1.' - assert tf.reduce_all(self.bits >= 0), 'Illegal bitwidth config: bits must be non-negative.' # type:ignore + if not tf.reduce_all(self.bits >= 0): # type:ignore + warn('Illegal bitwidth config: bits must be non-negative.') + self.bits.assign(tf.maximum(self.bits, 0)) conf = super().get_config() conf['RND'] = self.RND conf['SAT'] = self.SAT