UnaryLUT for non-linear activation
calad0i committed Jan 9, 2024
1 parent 1de2296 commit ef65120
Showing 5 changed files with 116 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/HGQ/proxy/__init__.py
@@ -2,5 +2,6 @@
from .fixed_point_quantizer import FixedPointQuantizer, fixed, gfixed, gfixed_quantizer, ufixed
# Register plugins
from .plugins import init_all
from .unary_lut import UnaryLUT

init_all()
10 changes: 8 additions & 2 deletions src/HGQ/proxy/convert.py
@@ -14,6 +14,7 @@
from ..utils import warn
from .fixed_point_quantizer import FixedPointQuantizer
from .precision_derivation import register_qconf
from .unary_lut import xfr_to_unary_lut


def get_all_nodes(model: keras.Model) -> set[Node]:
@@ -218,7 +219,7 @@ def to_keras_layer(layer):


@to_keras_layer.register
-def _(layer: HLayerBase | PLayerBase):
+def _(layer: ABSBaseLayer):
"""Given a HGQ layer, return the corresponding keras layer.
Example:
@@ -237,6 +238,9 @@ def _(layer: HLayerBase | PLayerBase):
    if 'activation' in conf and conf['activation'] != 'linear':
        # Activation will be processed separately for non-activation layers.
        conf['activation'] = 'linear'
    else:
        # Prevent custom activation crashing conversion.
        conf['activation'] = layer.activation

    cls_name = layer.__class__.__name__[1:]
    if hasattr(keras.layers, cls_name):
@@ -336,7 +340,7 @@ def _(self, layer: keras.layers.Activation):
        return layer


-def to_proxy_model(model: keras.Model, aggressive: bool = True, accum_fp_max_offset: int | None = None):
+def to_proxy_model(model: keras.Model, aggressive: bool = True, accum_fp_max_offset: int | None = None, uniary_lut_max_table_size=-1):
"""Given a HGQ model, return a hls4ml-ready keras model.
Args:
@@ -353,6 +357,8 @@ def to_proxy_model(model: keras.Model, aggressive: bool = True, accum_fp_max_off
    if accum_fp_max_offset is not None and accum_fp_max_offset < 0:
        warn('You are using a negative value for bias_accum_bits. Please make sure you know what you are doing.')
    proxy = convert_model(model, layer_xformer=ProxyLayerXFormer('WRAP' if aggressive else 'SAT').__call__)
    if uniary_lut_max_table_size > 0:
        proxy = convert_model(proxy, layer_xformer=partial(xfr_to_unary_lut, max_table_size=uniary_lut_max_table_size))
    for layer in proxy.layers:
        register_qconf(layer, accum_fp_max_offset=accum_fp_max_offset)
    return proxy
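For context, the new uniary_lut_max_table_size argument (spelling as committed) defaults to -1, so the extra lookup-table pass only runs when a positive table size is passed. A minimal usage sketch, assuming a trained HGQ model and a representative calibration batch; the variable names below are illustrative, not part of this commit:

from HGQ import trace_minmax
from HGQ.proxy import to_proxy_model

# `model` is a trained HGQ model, `calib_data` a representative input batch.
# trace_minmax records the observed value ranges used for precision derivation.
trace_minmax(model, calib_data, cover_factor=1.0)

# Convert to an hls4ml-ready proxy model. Activation layers whose quantized
# input domain fits in at most 1024 table entries are rewritten as UnaryLUT
# layers; larger ones (and relu/linear/softmax) are left unchanged.
proxy = to_proxy_model(model, aggressive=True, uniary_lut_max_table_size=1024)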
2 changes: 2 additions & 0 deletions src/HGQ/proxy/precision_derivation.py
@@ -385,6 +385,8 @@ def get_whatever_quantizer(layer: keras.layers.Layer):

def register_qconf(layer: keras.layers.Layer, accum_fp_max_offset: None | int = None):
    """Get and register quantization configuration for a layer in the proxy model."""
    if hasattr(layer, 'proxy_ready'):
        return
    q = get_whatever_quantizer(layer)
    conf = get_config(layer, accum_fp_max_offset=accum_fp_max_offset)
    overrides = q.overrides or {}
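This guard means any layer carrying a proxy_ready attribute is skipped by precision derivation entirely; the UnaryLUT layer added in this commit relies on it by declaring proxy_ready = True as a class attribute. A minimal sketch of the same opt-out pattern, with a hypothetical layer name:

import keras

class PrecomputedLayer(keras.layers.Layer):
    # Declaring proxy_ready makes register_qconf return before any quantizer
    # or config lookup, so the layer keeps the precision it was built with.
    proxy_ready = True

# register_qconf(PrecomputedLayer()) is now a no-op for this layer.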
102 changes: 102 additions & 0 deletions src/HGQ/proxy/unary_lut.py
@@ -0,0 +1,102 @@
import os
from collections.abc import Callable

import keras
import numpy as np
import tensorflow as tf
from keras.layers import Layer

from HGQ.proxy.fixed_point_quantizer import gfixed_quantizer
from HGQ.proxy.precision_derivation import get_input_kifs, get_produced_kif, get_result_kifRS
from HGQ.utils import apf_to_tuple, tuple_to_apf

LUT_SIZE_LIMITATION = int(os.environ.get('LUT_SIZE_LIMITATION', 2**20))


class UnaryLUT(Layer):
    proxy_ready = True

    def __init__(self, kif_in: tuple[int, int, int], kif_out: tuple[int, int, int], RND='TRN', SAT='WRAP', **kwargs):
        assert sum(kif_in) > 0, 'Input to activation is constantly zero'
        assert sum(kif_out) > 0, 'Output of activation is constantly zero'
        if LUT_SIZE_LIMITATION > 0:
            assert 2**sum(kif_in) < LUT_SIZE_LIMITATION, f'Input to activation is too large ({2**sum(kif_in)} > {LUT_SIZE_LIMITATION}). If you want to raise this limit, set the LUT_SIZE_LIMITATION environment variable.'
        self.kif_in = kif_in
        self.kif_out = kif_out
        k, i, f = kif_in
        self.scale = 2 ** f
        self.table = None
        if (table := kwargs.pop('table', None)) is not None:
            k, i, f = kif_out
            k, b, i = k, k + i + f, k + i
            table = gfixed_quantizer(table, k, b, i, RND, SAT)  # type:ignore
            self.table = tf.Variable(table, dtype='float32', trainable=False, name='table')
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):
        if not self.built:
            self.build(inputs.shape)
        inputs = tf.round(inputs * self.scale)
        inputs = inputs % self.table.shape[0]  # type:ignore
        return tf.gather(self.table, tf.cast(inputs, 'int32'))

    def build(self, input_shape):
        super().build(input_shape)
        if self.table is not None:
            return
        N = 2**sum(self.kif_in)
        self.table = tf.Variable(tf.zeros(N), dtype='float32', trainable=False, name='table')

    @classmethod
    def from_activation(cls, activation: Layer | Callable, kif_in=None, kifRS_out=None):

        if kif_in is None:
            kifs_in = get_input_kifs(activation)
            assert len(kifs_in) == 1, f'Activation function {activation} has more than one input. Please specify the input dtype.'
            kif_in = kifs_in[0]

        kifRS_out = kifRS_out or get_result_kifRS(activation)
        kif_out = kifRS_out[:3]
        R, S = kifRS_out[-2:]

        k, i, f = kif_in
        kif_in = k, i, f
        assert k + i + f > 0, 'Activation function is applied to a zero array. Something is wrong.'
        N = 2**(k + i + f)
        assert N < int(os.environ.get('HLS_MAX_ACTIVATION_LUT_SIZE', 2**16)), f'Input to activation function is too large ({N} > {os.environ.get("HLS_MAX_ACTIVATION_LUT_SIZE", 2**16)}). If you want to raise this limit, set the HLS_MAX_ACTIVATION_LUT_SIZE environment variable.'
        if k:
            inp_table = np.empty(N, dtype=np.float64)
            inp_table[:N // 2] = np.linspace(0, 2.**i - 2.**-f, N // 2, dtype=np.float64)
            inp_table[N // 2:] = inp_table[:N // 2] - 2.**i
        else:
            inp_table = np.linspace(-2.**i * k, 2.**i - 2.**-f, N, dtype=np.float64)
        table: np.ndarray = np.array(activation(inp_table), dtype=np.float32)
        return cls(kif_in, kif_out, table=table, RND=R, SAT=S)

    def get_config(self):
        config = super().get_config()
        config.update({
            'kif_in': self.kif_in,
            'kif_out': self.kif_out,
        })
        return config


def xfr_to_unary_lut(layer: keras.layers.Layer, max_table_size=1024):
    if not isinstance(layer, keras.layers.Activation):
        return layer
    if layer.activation is keras.activations.softmax:
        return layer  # simply doesn't work
    if layer.activation in (keras.activations.relu, keras.activations.linear):
        return layer  # not necessary
    kifs_in = get_input_kifs(layer)
    if len(kifs_in) > 1:
        return layer
    if 2**sum(*kifs_in) > max_table_size:
        return layer
    kif_in = kifs_in[0]

    return UnaryLUT.from_activation(layer, kif_in=kif_in)


get_produced_kif.register(UnaryLUT, lambda x: x.kif_out)
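To make the lookup semantics concrete: call() scales the input by 2**f to obtain an integer code, wraps it modulo the table length (matching the two's-complement layout built by from_activation, with negative inputs stored in the upper half of the table), and gathers the precomputed activation value. A rough NumPy illustration of the same indexing, assuming an input format of kif_in = (1, 2, 3), i.e. a signed input with 2 integer and 3 fractional bits; the numbers are illustrative only:

import numpy as np

k, i, f = 1, 2, 3                 # sign, integer and fractional bits of the input
N = 2 ** (k + i + f)              # 64 table entries

# Input grid in two's-complement order, as built by from_activation:
# non-negative codes first, then the negative ones.
inp = np.empty(N)
inp[:N // 2] = np.linspace(0, 2.0**i - 2.0**-f, N // 2)
inp[N // 2:] = inp[:N // 2] - 2.0**i
table = np.tanh(inp)              # any element-wise activation works here

x = np.array([-1.625, 0.25, 3.875])
idx = np.round(x * 2**f).astype(int) % N    # scale then wrap, like UnaryLUT.call
print(np.allclose(table[idx], np.tanh(x)))  # True: the lookup reproduces tanh(x)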
7 changes: 3 additions & 4 deletions test/helpers.py
@@ -16,7 +16,7 @@
from tensorflow import keras

from HGQ import trace_minmax
-from HGQ.proxy import FixedPointQuantizer, to_proxy_model
+from HGQ.proxy import FixedPointQuantizer, UnaryLUT, to_proxy_model

tf.get_logger().setLevel('ERROR')

@@ -76,8 +76,7 @@ def _run_model_sl_test(model: keras.Model, proxy: keras.Model, data, output_dir:
    model_loaded: keras.Model = keras.models.load_model(output_dir + '/keras.h5')  # type: ignore

    proxy.save(output_dir + '/proxy.h5')
-    proxy_loaded: keras.Model = keras.models.load_model(output_dir + '/proxy.h5', custom_objects={'FixedPointQuantizer': FixedPointQuantizer})  # type: ignore
-
+    proxy_loaded: keras.Model = keras.models.load_model(output_dir + '/proxy.h5', custom_objects={'FixedPointQuantizer': FixedPointQuantizer, 'UnaryLUT': UnaryLUT})  # type: ignore
    for l1, l2 in zip(proxy.layers, proxy_loaded.layers):
        if not isinstance(l1, FixedPointQuantizer):
            continue
@@ -133,7 +132,7 @@ def run_model_test(model: keras.Model, cover_factor: float | None, data, io_type
    _run_gradient_test(model, data)
    if cover_factor is not None:
        trace_minmax(model, data, cover_factor=cover_factor, bsz=data_len)
-    proxy = to_proxy_model(model, aggressive=aggressive)
+    proxy = to_proxy_model(model, aggressive=aggressive, uniary_lut_max_table_size=1024)
    try:
        if not skip_sl_test:
            _run_model_sl_test(model, proxy, data, dir)
