diff --git a/src/HGQ/quantizer/quantizer.py b/src/HGQ/quantizer/quantizer.py
index 82278ca..5387f81 100644
--- a/src/HGQ/quantizer/quantizer.py
+++ b/src/HGQ/quantizer/quantizer.py
@@ -1,9 +1,10 @@
+from collections.abc import Callable
 from functools import singledispatchmethod
 
 import numpy as np
 import tensorflow as tf
 
-from ..utils import L1, L1L2, L2, strategy_dict
+from ..utils import strategy_dict
 
 two = tf.constant(2, dtype=tf.float32)
 log2 = tf.constant(np.log(2), dtype=tf.float32)
@@ -54,7 +55,7 @@ def get_arr_bits(arr: np.ndarray):
 class HGQ:
     """Heterogenous quantizer."""
 
-    def __init__(self, init_bw: float, skip_dims, rnd_strategy: str | int = 'floor', exact_q_value=True, dtype=None, bw_clip=(-23, 23), trainable=True, regularizer=None, minmax_record=False):
+    def __init__(self, init_bw: float, skip_dims, rnd_strategy: str | int = 'floor', exact_q_value=True, dtype=None, bw_clip=(-23, 23), trainable=True, regularizer: Callable | None = None, minmax_record=False):
         self.init_bw = init_bw
         self.skip_dims = skip_dims
         """tuple: Dimensions to use uniform quantizer. If None, use full heterogenous quantizer."""
diff --git a/src/HGQ/utils/__init__.py b/src/HGQ/utils/__init__.py
index c07706e..94b2537 100644
--- a/src/HGQ/utils/__init__.py
+++ b/src/HGQ/utils/__init__.py
@@ -1 +1 @@
-from .utils import L1, L1L2, L2, apf_to_tuple, get_default_kq_conf, get_default_paq_conf, set_default_kq_conf, set_default_paq_conf, strategy_dict, tuple_to_apf, warn
+from .utils import MonoL1, apf_to_tuple, get_default_kq_conf, get_default_paq_conf, set_default_kq_conf, set_default_paq_conf, strategy_dict, tuple_to_apf, warn
diff --git a/src/HGQ/utils/utils.py b/src/HGQ/utils/utils.py
index 08cde05..7902546 100644
--- a/src/HGQ/utils/utils.py
+++ b/src/HGQ/utils/utils.py
@@ -1,11 +1,12 @@
 import re
-import sys
 from warnings import warn as _warn
 
+import keras
 import tensorflow as tf
 
 
-class L1:
+@keras.utils.register_keras_serializable(package='HGQ')
+class MonoL1:
     def __init__(self, l1=0.):
         assert l1 >= 0, f'l1 must be non-negative, got {l1}'
         self.l1 = l1
@@ -17,34 +18,6 @@ def get_config(self):
         return {'l1': self.l1}
 
 
-class L2:
-    def __init__(self, l2=0., zero=-16.):
-        assert l2 >= 0, f'l2 must be non-negative, got {l2}'
-        self.l2 = l2
-        self.zero = zero
-
-    def __call__(self, x):
-        return tf.reduce_sum(tf.square(x - self.zero)) * self.l2
-
-    def get_config(self):
-        return {'l2': self.l2, 'zero': self.zero}
-
-
-class L1L2:
-    def __init__(self, l1=0., l2=0., l2_zero=-16.):
-        assert l1 >= 0, f'l1 must be non-negative, got {l1}'
-        assert l2 >= 0, f'l2 must be non-negative, got {l2}'
-        self.l1 = l1
-        self.l2 = l2
-        self.l2_zero = l2_zero
-
-    def __call__(self, x):
-        return tf.reduce_sum(x) * self.l1 + tf.reduce_sum(tf.square(x - self.l2_zero)) * self.l2
-
-    def get_config(self):
-        return {'l1': self.l1, 'l2': self.l2, 'l2_zero': self.l2_zero}
-
-
 DEFAULT_KERNEL_QUANTIZER_CONFIG = \
     dict(init_bw=2,
          skip_dims=None,
@@ -53,7 +26,7 @@ def get_config(self):
          dtype=None,
          bw_clip=(-23, 23),
          trainable=True,
-         regularizer=L1(1e-6),
+         regularizer=MonoL1(1e-6),
          )
 
 
@@ -65,7 +38,7 @@ def get_config(self):
          dtype=None,
          bw_clip=(-23, 23),
          trainable=True,
-         regularizer=L1(1e-6),
+         regularizer=MonoL1(1e-6),
          minmax_record=True
          )
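
Why this change matters: the penalty computed by the renamed class, `l1 * reduce_sum(x)`, is monotonic in every element of `x`, so shrinking any bitwidth parameter always shrinks the penalty; that is presumably the "Mono" in the new name, and the reason the zero-anchored `L2`/`L1L2` variants are dropped. The rename also sidesteps confusion with Keras's built-in `L1`/`L2`/`L1L2` regularizers, and the new `@keras.utils.register_keras_serializable(package='HGQ')` decorator puts the class into Keras's global custom-object registry under the key `'HGQ>MonoL1'`, so saved models carrying the default `MonoL1(1e-6)` regularizer can be reloaded without a `custom_objects` argument. Below is a minimal sketch of both behaviors, assuming the patched `HGQ.utils` is importable and that `MonoL1.__call__` is the old `L1.__call__` (the hunks above only rename the class and add the decorator; `__call__` sits between hunks and is unchanged):

    import keras
    import tensorflow as tf

    from HGQ.utils import MonoL1

    # Penalty on trainable bitwidth parameters: l1 * sum(bw).
    bw = tf.constant([2.0, 3.0, 1.5])
    reg = MonoL1(l1=1e-6)
    print(float(reg(bw)))  # 1e-6 * (2.0 + 3.0 + 1.5) = 6.5e-06

    # The decorator registered the class under '<package>><name>'.
    cls = keras.utils.get_custom_objects()['HGQ>MonoL1']
    restored = cls(**reg.get_config())  # config round-trip
    assert isinstance(restored, MonoL1) and restored.l1 == reg.l1

Before this patch the unregistered `L1` would have had to be supplied by hand, e.g. `keras.models.load_model(path, custom_objects={'L1': L1})`, when reloading a saved model that used the default quantizer configs.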