pass parallelization_factor, rename rule change
calad0i committed Jan 26, 2024
1 parent: 100e35a · commit: 77c882e
1 changed file: src/HGQ/proxy/convert.py (12 additions, 4 deletions)
@@ -127,8 +127,6 @@ def _get_out_tensor(
     if len(inps) == 1:
         inps = inps[0]
 
-    name = namer.next_name(layer.name)
-    layer._name = name
     satisfied[node] = apply_layer(layer, inps, namer=namer, layer_xformer=layer_xformer)
     return satisfied[node]
 
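The two deleted lines are not lost: the renaming moves into apply_layer (next two hunks), so the unique name is now assigned to the transformed layer instead of the layer fed into the transformer.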

@@ -149,6 +147,7 @@ def apply_layer(
 
     layer_xf = layer_xformer(layer)
     n = 0
+    namer = namer or Namer()
     while layer_xf is not layer:
         if isinstance(layer_xf, keras.Model):
             # layer_transformer may return a keras.Model from a layer.
@@ -157,6 +156,10 @@
         layer = layer_xf
         layer_xf = layer_xformer(layer)
         n += 1
+
+    name = namer.next_name(layer_xf.name)
+    layer_xf._name = name
+
     if layer_xf is not None:
         # Remove all inbound and outbound nodes to clean up the graph.
         layer_xf = copy(layer_xf)
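A minimal sketch of the naming idiom these hunks set up, using a hypothetical stand-in for HGQ's Namer (the real class lives elsewhere in the package). It illustrates why the namer = namer or Namer() fallback makes apply_layer safe to call without a namer, and how renaming after the transform loop deduplicates the final layer's name:

class Namer:
    # Hypothetical stand-in, for illustration only: yields 'name',
    # then 'name_1', 'name_2', ... for repeated requests.
    def __init__(self):
        self._counts = {}

    def next_name(self, name: str) -> str:
        n = self._counts.get(name, 0)
        self._counts[name] = n + 1
        return name if n == 0 else f'{name}_{n}'

namer = None
namer = namer or Namer()  # the fallback this commit adds
assert namer.next_name('dense') == 'dense'
assert namer.next_name('dense') == 'dense_1'  # collision avoided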
@@ -306,6 +309,11 @@ def _(self, layer: HLayerBase):
         else:
             inputs = [keras.layers.Input(shape=shape[1:]) for shape in input_shape]
 
+        overrides = {'layers': {}}
+        if hasattr(layer, 'parallel_factor'):
+            parallel_factor = layer.parallel_factor
+            overrides = {'layers': {name: {'parallelization_factor': int(parallel_factor)}}}
+
         if not isinstance(klayer, keras.layers.Activation):
             if layer._relu_act:  # a non-activation layer can only have relu or linear as activation
                 k, i, f = layer.paq.get_bits_exact(pos_only=False)
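For reference, the overrides payload built above has this shape (SimpleNamespace stands in for an HGQ layer here, purely for illustration):

from types import SimpleNamespace

layer = SimpleNamespace(parallel_factor=4)  # stand-in for an HGQ layer exposing parallel_factor
name = 'conv1'                              # stand-in for the proxy layer name

overrides = {'layers': {}}
if hasattr(layer, 'parallel_factor'):
    overrides = {'layers': {name: {'parallelization_factor': int(layer.parallel_factor)}}}

print(overrides)  # {'layers': {'conv1': {'parallelization_factor': 4}}}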
@@ -315,11 +323,11 @@ def _(self, layer: HLayerBase):
 
                 f_add_bits = 0 if R == 'TRN' else 1 if R == 'RND' else 2
                 fq1 = FixedPointQuantizer(k, b + f_add_bits, i, SAT=S, RND='TRN', name=f'{name}_quantizer')
-                fq2 = FixedPointQuantizer(rk, rb, ri, SAT=S, RND=R, name=f'{name}_relu_quantizer')
+                fq2 = FixedPointQuantizer(rk, rb, ri, SAT=S, RND=R, name=f'{name}_relu_quantizer', overrides=overrides)
                 return keras.Model(inputs, fq2(keras.layers.ReLU()(fq1(klayer(inputs)))))
 
         k, b, i, R, S = self.get_kbiRS(layer)
-        q = FixedPointQuantizer(k, b, i, R, SAT=S, name=f'{name}_quantizer')
+        q = FixedPointQuantizer(k, b, i, R, SAT=S, name=f'{name}_quantizer', overrides=overrides)
         return keras.Model(inputs, q(klayer(inputs)))
 
     @__call__.register
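Passing overrides at both FixedPointQuantizer call sites means the parallelization_factor hint travels with the quantizer on either branch of the proxy model: the fused ReLU path (fq2) and the plain path (q). Read together, the two halves of the commit title map onto these hunks: 'pass parallelization_factor' is the overrides forwarding, and 'rename rule change' is the naming rework in apply_layer above.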
