Skip to content

Commit

Permalink
Merge pull request #59 from sony/fix_keras_wrapper
Browse files Browse the repository at this point in the history
Fix Keras Wrapper:
* Remove "convert_to_inferable" from wrapper (moved to MCT wrapper)
* fix _trainable_weights & _non_trainable_weights handling
  • Loading branch information
elad-c authored Nov 8, 2023
2 parents dfb6750 + 01befca commit 16895f2
Showing 1 changed file with 6 additions and 29 deletions.
35 changes: 6 additions & 29 deletions mct_quantizers/keras/quantize_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,13 @@ def _set_weights_vars(self, is_training: bool = True):
weight = getattr(self.layer, name)
quantizer.initialize_quantization(weight.shape, _weight_name(weight.name) if is_training else None,
self)
# Add weight to wrapper weight lists (rather than the layer weight lists), because it will be deleted
# from the layer's lists after the first call
self._weights_vars.append((name, weight, quantizer))
self._trainable_weights.append(weight) # Must when inherit from tf.keras.layers.Wrapper in tf2.10 and below
if is_training and not any([weight is w for w in self._trainable_weights]):
self._trainable_weights.append(weight)
elif not is_training and any([weight is w for w in self._non_trainable_weights]):
self._non_trainable_weights.append(weight)

@classmethod
def from_config(cls, config):
Expand Down Expand Up @@ -238,34 +243,6 @@ def call(self, inputs, training=None, **kwargs):

return outputs

def convert_to_inferable_quantizers(self):
"""
Convert layer's quantizers to inferable.
Returns:
None
"""
# Weight quantizers
inferable_weight_quantizers = {}
if self.is_weights_quantization:
for name, quantizer in self.weights_quantizers.items():
if hasattr(quantizer, 'convert2inferable') and callable(quantizer.convert2inferable):
inferable_weight_quantizers.update({name: quantizer.convert2inferable()})
self.weights_quantizers = inferable_weight_quantizers

# Create new layer with inferable quantizers
inferable_quantizers_wrapper = self.from_config(self.get_config())
inferable_quantizers_wrapper.layer.build(self.get_input_shape_at(0))
layer_weights_list = []
for weight_attr in self.weights_quantizers.keys():
layer_weights_list.append(getattr(self.layer, weight_attr)) # quantized weights
layer_weights_list.extend(self.layer.get_weights()) # non quantized weights
inferable_quantizers_wrapper.layer.set_weights(layer_weights_list)

# The wrapper inference is using the weights of the quantizers so it expectes to create them by running _set_weights_vars
inferable_quantizers_wrapper._set_weights_vars(False)
return inferable_quantizers_wrapper

def get_weights_vars(self) -> List[Tuple[str, Any, BaseInferableQuantizer]]:
"""
A getter of the layer's weights variables.
Expand Down

0 comments on commit 16895f2

Please sign in to comment.