diff --git a/src/HGQ/bops/bops.py b/src/HGQ/bops/bops.py index b83172d..a681cb3 100644 --- a/src/HGQ/bops/bops.py +++ b/src/HGQ/bops/bops.py @@ -12,17 +12,29 @@ def __init__(self): def on_epoch_end(self, epoch, logs=None): assert self.model is not None + logs['bops'] = self._bops(self.model) # type: ignore + + def _bops(self, model): bops = 0 - for layer in self.model.layers: + for layer in model.layers: if hasattr(layer, 'bops'): bops += layer.bops.numpy() - logs['bops'] = bops # type: ignore + continue + if isinstance(layer, keras.Model): + bops += self._bops(layer) + return bops class ResetMinMax(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, logs=None): assert self.model is not None - for layer in self.model.layers: + self.reset_minmax(self.model) + + @staticmethod + def reset_minmax(model): + for layer in model.layers: + if isinstance(layer, keras.Model): + ResetMinMax.reset_minmax(layer) if isinstance(layer, HLayerBase): layer.reset_minmax()