Commit

register all layers serializable
calad0i committed Apr 26, 2024
1 parent f67870b commit 02d4b73
Showing 3 changed files with 6 additions and 4 deletions.
src/HGQ/proxy/fixed_point_quantizer.py (2 additions, 1 deletion)
@@ -143,6 +143,7 @@ def fixed(bits, integer_bits, RND='TRN', SAT='WRAP') -> Callable:
     return gfixed(1, bits, integer_bits, RND, SAT)
 
 
+@keras.utils.register_keras_serializable(package='HGQ')
 class FixedPointQuantizer(keras.layers.Layer, metaclass=abc.ABCMeta):
 
     def __init__(self, keep_negative, bits, integers, RND: str = 'TRN', SAT: str = 'WRAP', overrides: dict | None = None, **kwargs):
@@ -174,7 +175,7 @@ def __init__(self, keep_negative, bits, integers, RND: str = 'TRN', SAT: str = '
 
         super().__init__(trainable=False, **kwargs)
 
-    def call(self, x, training=None):
+    def call(self, x, training=None):  # type:ignore
         assert not training, "Proxy model shall can not be trained!"
         if not self.built:
             self.build(x.shape)
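A note on the decorator added above: keras.utils.register_keras_serializable(package='HGQ') records a class in Keras' global serialization registry under a 'package>ClassName' key, so the deserializer can resolve it by name when a saved model is loaded. A minimal sketch of the mechanism, using a hypothetical Scale layer that is not part of HGQ:

from tensorflow import keras


# Toy layer, registered the same way as the HGQ layers in this commit.
@keras.utils.register_keras_serializable(package='HGQ')
class Scale(keras.layers.Layer):
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        # Needed so the layer can be rebuilt from its saved config.
        return {**super().get_config(), 'factor': self.factor}


# The decorator registers the class under '<package>><ClassName>'.
print(keras.utils.get_registered_name(Scale))          # -> 'HGQ>Scale'
print(keras.utils.get_registered_object('HGQ>Scale'))  # -> the Scale class

Because FixedPointQuantizer (and UnaryLUT below) are now registered this way, Keras can reconstruct them from a saved model without being handed the classes explicitly; the test change in test/helpers.py relies on exactly that.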
src/HGQ/proxy/unary_lut.py (2 additions, 1 deletion)
@@ -13,6 +13,7 @@
 LUT_SIZE_LIMITATION = int(os.environ.get('LUT_SIZE_LIMITATION', 2**12))
 
 
+@keras.utils.register_keras_serializable(package='HGQ')
 class UnaryLUT(Layer):
     proxy_ready = True
 
@@ -33,7 +34,7 @@ def __init__(self, kif_in: tuple[int, int, int], kif_out: tuple[int, int, int],
         self.table = tf.Variable(table, dtype='float32', trainable=False, name='table')
         super().__init__(**kwargs)
 
-    def call(self, inputs, **kwargs):
+    def call(self, inputs, **kwargs):  # type:ignore
         if not self.built:
             self.build(inputs.shape)
         inputs = tf.round(inputs * self.scale)
test/helpers.py (2 additions, 2 deletions)
@@ -16,7 +16,7 @@
 from tensorflow import keras
 
 from HGQ import trace_minmax
-from HGQ.proxy import FixedPointQuantizer, UnaryLUT, to_proxy_model
+from HGQ.proxy import FixedPointQuantizer, to_proxy_model
 
 tf.get_logger().setLevel('ERROR')
 
@@ -76,7 +76,7 @@ def _run_model_sl_test(model: keras.Model, proxy: keras.Model, data, output_dir:
     model_loaded: keras.Model = keras.models.load_model(output_dir + '/keras.h5')  # type: ignore
 
     proxy.save(output_dir + '/proxy.h5')
-    proxy_loaded: keras.Model = keras.models.load_model(output_dir + '/proxy.h5', custom_objects={'FixedPointQuantizer': FixedPointQuantizer, 'UnaryLUT': UnaryLUT})  # type: ignore
+    proxy_loaded: keras.Model = keras.models.load_model(output_dir + '/proxy.h5')  # type: ignore
     for l1, l2 in zip(proxy.layers, proxy_loaded.layers):
         if not isinstance(l1, FixedPointQuantizer):
             continue
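The helpers.py change is the payoff of the registration: once the layers are registered, keras.models.load_model resolves them from the global registry and no longer needs a custom_objects mapping. A sketch of the round trip the helper now performs (the function name and path below are illustrative, not taken from the test suite):

from tensorflow import keras

# Importing the package is still required so that the
# @register_keras_serializable decorators run and populate the registry.
import HGQ.proxy  # noqa: F401


def save_and_reload(proxy: keras.Model, path: str = '/tmp/proxy.h5') -> keras.Model:
    """Save a proxy model and load it back without custom_objects."""
    proxy.save(path)
    # Before this commit: custom_objects={'FixedPointQuantizer': ..., 'UnaryLUT': ...}
    # had to be passed here; the registry now resolves those classes by name.
    return keras.models.load_model(path)  # type: ignore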
