diff --git a/src/HGQ/proxy/precision_derivation.py b/src/HGQ/proxy/precision_derivation.py index 085102a..74c6ac6 100644 --- a/src/HGQ/proxy/precision_derivation.py +++ b/src/HGQ/proxy/precision_derivation.py @@ -194,7 +194,7 @@ def get_requested_kif(layer: keras.layers.Layer | FixedPointQuantizer) -> tuple[ @singledispatch def get_request_kif(layer: keras.layers.Layer) -> tuple[int, int, int]: """Get the requested bitwidth of a layer, as a tuple of (k, i, f)""" - if isinstance(layer, (Pooling1D, Pooling2D, Pooling3D, Concatenate, Reshape, Flatten)): + if isinstance(layer, (Concatenate, Reshape, Flatten)): out_layers: list[keras.layers.Layer] = [node.outbound_layer for node in layer._outbound_nodes] if out_layers: # Layers that does nothing. Pass through.