diff --git a/src/brevitas/core/bit_width/parameter.py b/src/brevitas/core/bit_width/parameter.py index 8a54a07df..d5dd63adc 100644 --- a/src/brevitas/core/bit_width/parameter.py +++ b/src/brevitas/core/bit_width/parameter.py @@ -106,7 +106,8 @@ def _load_from_state_dict( del state_dict[bit_width_offset_key] super(BitWidthParameter, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - if config.IGNORE_MISSING_KEYS and bit_width_offset_key in missing_keys: + if (config.IGNORE_MISSING_KEYS or + self.override_pretrained) and bit_width_offset_key in missing_keys: missing_keys.remove(bit_width_offset_key) @@ -147,5 +148,6 @@ def _load_from_state_dict( del state_dict[bit_width_coeff_key] super(RemoveBitwidthParameter, self)._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - if config.IGNORE_MISSING_KEYS and bit_width_coeff_key in missing_keys: + if (config.IGNORE_MISSING_KEYS or + self.override_pretrained) and bit_width_coeff_key in missing_keys: missing_keys.remove(bit_width_coeff_key) diff --git a/tests/brevitas/core/test_bit_width.py b/tests/brevitas/core/test_bit_width.py index e8b1c7879..51883ae18 100644 --- a/tests/brevitas/core/test_bit_width.py +++ b/tests/brevitas/core/test_bit_width.py @@ -142,6 +142,7 @@ def test_load_from_stateful_const( """ if (bit_width_init_two < min_bit_width_init) and not override_pretrained: pytest.xfail('bit_width cannot be smaller than min_bit_width') + override_value = bit_width_parameter.bit_width_offset bit_width_parameter.load_state_dict(bit_width_stateful_const.state_dict()) bit_width_parameter_tensor = bit_width_parameter()