redone regs for bw
calad0i committed Nov 30, 2023
Parent: 72d52fd. Commit: 1bc96de.
Showing 3 changed files with 9 additions and 35 deletions.
5 changes: 3 additions & 2 deletions src/HGQ/quantizer/quantizer.py
@@ -1,9 +1,10 @@
+from collections.abc import Callable
 from functools import singledispatchmethod
 
 import numpy as np
 import tensorflow as tf
 
-from ..utils import L1, L1L2, L2, strategy_dict
+from ..utils import strategy_dict
 
 two = tf.constant(2, dtype=tf.float32)
 log2 = tf.constant(np.log(2), dtype=tf.float32)
@@ -54,7 +55,7 @@ def get_arr_bits(arr: np.ndarray):
 class HGQ:
     """Heterogenous quantizer."""
 
-    def __init__(self, init_bw: float, skip_dims, rnd_strategy: str | int = 'floor', exact_q_value=True, dtype=None, bw_clip=(-23, 23), trainable=True, regularizer=None, minmax_record=False):
+    def __init__(self, init_bw: float, skip_dims, rnd_strategy: str | int = 'floor', exact_q_value=True, dtype=None, bw_clip=(-23, 23), trainable=True, regularizer: Callable | None = None, minmax_record=False):
         self.init_bw = init_bw
         self.skip_dims = skip_dims
         """tuple: Dimensions to use uniform quantizer. If None, use full heterogenous quantizer."""
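
With this change, `regularizer` is annotated as `Callable | None`, so any callable mapping the bitwidth tensor to a scalar penalty is accepted, not only the regularizer classes this commit removes. A minimal sketch of constructing the quantizer after this commit; the import paths and argument values are illustrative assumptions inferred from the file locations in this diff, not something the diff itself confirms:

    import tensorflow as tf

    from HGQ.quantizer.quantizer import HGQ  # assumed path, matching the file shown above
    from HGQ.utils import MonoL1             # exported per the __init__.py change below

    # MonoL1 presumably computes l1 * sum(bw), mirroring the L1 term of the
    # removed L1L2 class below; its __call__ body is outside the shown hunks.
    q = HGQ(init_bw=2, skip_dims=None, regularizer=MonoL1(1e-6))

    # Any plain callable now satisfies the annotation as well:
    q2 = HGQ(init_bw=2, skip_dims=None, regularizer=lambda bw: 1e-6 * tf.reduce_sum(bw))
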
2 changes: 1 addition & 1 deletion src/HGQ/utils/__init__.py
@@ -1 +1 @@
-from .utils import L1, L1L2, L2, apf_to_tuple, get_default_kq_conf, get_default_paq_conf, set_default_kq_conf, set_default_paq_conf, strategy_dict, tuple_to_apf, warn
+from .utils import MonoL1, apf_to_tuple, get_default_kq_conf, get_default_paq_conf, set_default_kq_conf, set_default_paq_conf, strategy_dict, tuple_to_apf, warn
37 changes: 5 additions & 32 deletions src/HGQ/utils/utils.py
@@ -1,11 +1,12 @@
 import re
 import sys
 from warnings import warn as _warn
 
 import keras
 import tensorflow as tf
 
 
-class L1:
+@keras.utils.register_keras_serializable(package='HGQ')
+class MonoL1:
     def __init__(self, l1=0.):
         assert l1 >= 0, f'l1 must be non-negative, got {l1}'
         self.l1 = l1
@@ -17,34 +18,6 @@ def get_config(self):
         return {'l1': self.l1}
 
 
-class L2:
-    def __init__(self, l2=0., zero=-16.):
-        assert l2 >= 0, f'l2 must be non-negative, got {l2}'
-        self.l2 = l2
-        self.zero = zero
-
-    def __call__(self, x):
-        return tf.reduce_sum(tf.square(x - self.zero)) * self.l2
-
-    def get_config(self):
-        return {'l2': self.l2, 'zero': self.zero}
-
-
-class L1L2:
-    def __init__(self, l1=0., l2=0., l2_zero=-16.):
-        assert l1 >= 0, f'l1 must be non-negative, got {l1}'
-        assert l2 >= 0, f'l2 must be non-negative, got {l2}'
-        self.l1 = l1
-        self.l2 = l2
-        self.l2_zero = l2_zero
-
-    def __call__(self, x):
-        return tf.reduce_sum(x) * self.l1 + tf.reduce_sum(tf.square(x - self.l2_zero)) * self.l2
-
-    def get_config(self):
-        return {'l1': self.l1, 'l2': self.l2, 'l2_zero': self.l2_zero}
-
-
 DEFAULT_KERNEL_QUANTIZER_CONFIG = \
     dict(init_bw=2,
          skip_dims=None,
@@ -53,7 +26,7 @@ def get_config(self):
          dtype=None,
          bw_clip=(-23, 23),
          trainable=True,
-         regularizer=L1(1e-6),
+         regularizer=MonoL1(1e-6),
          )
@@ -65,7 +38,7 @@ def get_config(self):
          dtype=None,
          bw_clip=(-23, 23),
          trainable=True,
-         regularizer=L1(1e-6),
+         regularizer=MonoL1(1e-6),
          minmax_record=True
          )
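
The surviving regularizer is renamed `MonoL1` and registered in Keras's custom-object registry via `@keras.utils.register_keras_serializable(package='HGQ')`, so models whose quantizers carry it can be saved and reloaded without passing `custom_objects` by hand. A rough sketch of the round trip, assuming the standard `keras.utils` serialize/deserialize helpers and that Keras falls back to `MonoL1(**config)` on load, since the class defines `get_config` but no `from_config`:

    import keras

    from HGQ.utils import MonoL1  # per the __init__.py change above

    reg = MonoL1(l1=1e-6)
    cfg = keras.utils.serialize_keras_object(reg)
    # cfg names the registered class (roughly 'HGQ>MonoL1') plus {'l1': 1e-06} from get_config().

    restored = keras.utils.deserialize_keras_object(cfg)
    assert isinstance(restored, MonoL1) and restored.l1 == reg.l1
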
