Skip to content

Commit

Permalink
allow nested layers
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Jun 17, 2024
1 parent 2b85706 commit eafebc7
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/HGQ/bops/bops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,29 @@ def __init__(self):

def on_epoch_end(self, epoch, logs=None):
assert self.model is not None
logs['bops'] = self._bops(self.model) # type: ignore

def _bops(self, model):
bops = 0
for layer in self.model.layers:
for layer in model.layers:
if hasattr(layer, 'bops'):
bops += layer.bops.numpy()
logs['bops'] = bops # type: ignore
continue
if isinstance(layer, keras.Model):
bops += self._bops(layer)
return bops


class ResetMinMax(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
assert self.model is not None
for layer in self.model.layers:
self.reset_minmax(self.model)

@staticmethod
def reset_minmax(model):
for layer in model.layers:
if isinstance(layer, keras.Model):
ResetMinMax.reset_minmax(layer)
if isinstance(layer, HLayerBase):
layer.reset_minmax()

Expand Down

0 comments on commit eafebc7

Please sign in to comment.