From 412e4a7ae01b64c1d49e33771e46b41f40a49170 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 19 Jul 2024 19:22:01 -0700 Subject: [PATCH] Add Permute layer --- src/HGQ/layers/passive_layers.py | 5 +++++ src/HGQ/proxy/precision_derivation.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/HGQ/layers/passive_layers.py b/src/HGQ/layers/passive_layers.py index aa5170b..28ef324 100644 --- a/src/HGQ/layers/passive_layers.py +++ b/src/HGQ/layers/passive_layers.py @@ -175,3 +175,8 @@ class PAveragePooling1D(PPool1D, tf.keras.layers.AvgPool1D): @register_keras_serializable(package="HGQ") class PDropout(PLayerBase, tf.keras.layers.Dropout): pass + + +@register_keras_serializable(package="HGQ") +class PPermute(PLayerBase, tf.keras.layers.Permute): + pass diff --git a/src/HGQ/proxy/precision_derivation.py b/src/HGQ/proxy/precision_derivation.py index fdcc403..9f2096c 100644 --- a/src/HGQ/proxy/precision_derivation.py +++ b/src/HGQ/proxy/precision_derivation.py @@ -3,7 +3,7 @@ import numpy as np import tensorflow as tf -from keras.layers import AvgPool1D, AvgPool2D, AvgPool3D, Concatenate, Flatten, MaxPool1D, MaxPool2D, MaxPool3D, Reshape +from keras.layers import AvgPool1D, AvgPool2D, AvgPool3D, Concatenate, Flatten, MaxPool1D, MaxPool2D, MaxPool3D, Permute, Reshape from keras.src.layers.convolutional.base_conv import Conv from keras.src.layers.pooling.base_pooling1d import Pooling1D from keras.src.layers.pooling.base_pooling2d import Pooling2D @@ -159,7 +159,7 @@ def _(layer: keras.layers.Dense | Conv): @get_produced_kif.register -def _(layer: Reshape | Flatten | MaxPool3D | MaxPool2D | MaxPool1D): +def _(layer: Reshape | Flatten | MaxPool3D | MaxPool2D | MaxPool1D | Permute): kifs = get_input_kifs(layer) assert len(kifs) == 1, f'Flatten/Reshape layer {layer.name} has more than one input. This is not supported.' k, i, f = kifs[0]