diff --git a/src/HGQ/proxy/fixed_point_quantizer.py b/src/HGQ/proxy/fixed_point_quantizer.py index 6588204..4518895 100644 --- a/src/HGQ/proxy/fixed_point_quantizer.py +++ b/src/HGQ/proxy/fixed_point_quantizer.py @@ -143,6 +143,7 @@ def fixed(bits, integer_bits, RND='TRN', SAT='WRAP') -> Callable: return gfixed(1, bits, integer_bits, RND, SAT) +@keras.utils.register_keras_serializable(package='HGQ') class FixedPointQuantizer(keras.layers.Layer, metaclass=abc.ABCMeta): def __init__(self, keep_negative, bits, integers, RND: str = 'TRN', SAT: str = 'WRAP', overrides: dict | None = None, **kwargs): @@ -174,7 +175,7 @@ def __init__(self, keep_negative, bits, integers, RND: str = 'TRN', SAT: str = ' super().__init__(trainable=False, **kwargs) - def call(self, x, training=None): + def call(self, x, training=None): # type:ignore assert not training, "Proxy model shall can not be trained!" if not self.built: self.build(x.shape) diff --git a/src/HGQ/proxy/unary_lut.py b/src/HGQ/proxy/unary_lut.py index 76e529a..93a7f72 100644 --- a/src/HGQ/proxy/unary_lut.py +++ b/src/HGQ/proxy/unary_lut.py @@ -13,6 +13,7 @@ LUT_SIZE_LIMITATION = int(os.environ.get('LUT_SIZE_LIMITATION', 2**12)) +@keras.utils.register_keras_serializable(package='HGQ') class UnaryLUT(Layer): proxy_ready = True @@ -33,7 +34,7 @@ def __init__(self, kif_in: tuple[int, int, int], kif_out: tuple[int, int, int], self.table = tf.Variable(table, dtype='float32', trainable=False, name='table') super().__init__(**kwargs) - def call(self, inputs, **kwargs): + def call(self, inputs, **kwargs): # type:ignore if not self.built: self.build(inputs.shape) inputs = tf.round(inputs * self.scale) diff --git a/test/helpers.py b/test/helpers.py index e7bf442..44bc21f 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -16,7 +16,7 @@ from tensorflow import keras from HGQ import trace_minmax -from HGQ.proxy import FixedPointQuantizer, UnaryLUT, to_proxy_model +from HGQ.proxy import FixedPointQuantizer, to_proxy_model tf.get_logger().setLevel('ERROR') @@ -76,7 +76,7 @@ def _run_model_sl_test(model: keras.Model, proxy: keras.Model, data, output_dir: model_loaded: keras.Model = keras.models.load_model(output_dir + '/keras.h5') # type: ignore proxy.save(output_dir + '/proxy.h5') - proxy_loaded: keras.Model = keras.models.load_model(output_dir + '/proxy.h5', custom_objects={'FixedPointQuantizer': FixedPointQuantizer, 'UnaryLUT': UnaryLUT}) # type: ignore + proxy_loaded: keras.Model = keras.models.load_model(output_dir + '/proxy.h5') # type: ignore for l1, l2 in zip(proxy.layers, proxy_loaded.layers): if not isinstance(l1, FixedPointQuantizer): continue