Skip to content

Commit

Permalink
Add Permute layer
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Aug 7, 2024
1 parent a4ffb35 commit 412e4a7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 5 additions & 0 deletions src/HGQ/layers/passive_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,8 @@ class PAveragePooling1D(PPool1D, tf.keras.layers.AvgPool1D):
@register_keras_serializable(package="HGQ")
class PDropout(PLayerBase, tf.keras.layers.Dropout):
pass


@register_keras_serializable(package="HGQ")
class PPermute(PLayerBase, tf.keras.layers.Permute):
pass
4 changes: 2 additions & 2 deletions src/HGQ/proxy/precision_derivation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import tensorflow as tf
from keras.layers import AvgPool1D, AvgPool2D, AvgPool3D, Concatenate, Flatten, MaxPool1D, MaxPool2D, MaxPool3D, Reshape
from keras.layers import AvgPool1D, AvgPool2D, AvgPool3D, Concatenate, Flatten, MaxPool1D, MaxPool2D, MaxPool3D, Permute, Reshape
from keras.src.layers.convolutional.base_conv import Conv
from keras.src.layers.pooling.base_pooling1d import Pooling1D
from keras.src.layers.pooling.base_pooling2d import Pooling2D
Expand Down Expand Up @@ -159,7 +159,7 @@ def _(layer: keras.layers.Dense | Conv):


@get_produced_kif.register
def _(layer: Reshape | Flatten | MaxPool3D | MaxPool2D | MaxPool1D):
def _(layer: Reshape | Flatten | MaxPool3D | MaxPool2D | MaxPool1D | Permute):
kifs = get_input_kifs(layer)
assert len(kifs) == 1, f'Flatten/Reshape layer {layer.name} has more than one input. This is not supported.'
k, i, f = kifs[0]
Expand Down

0 comments on commit 412e4a7

Please sign in to comment.