Further optimize activation
Aba committed Sep 16, 2023
1 parent: 2a8a49e · commit: 3103116
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions test/py/bundle.py
@@ -287,19 +287,18 @@ def apply_act(act_dict):

     if act_dict['type'] == 'quant':
         shift_bits = self.proc['frac']-frac

         x = shift_round(x, shift_bits) # = np.around(x/2**shift_bits)
         x = np.clip(x, -2**(bits-1), 2**(bits-1)-1).astype(int)

     elif act_dict['type'] == 'relu':
-        nlog_act_dict = -int(np.log2(act_dict['slope']))
-        assert nlog_act_dict == -np.log2(act_dict['slope']), f"Leaky Relu slope: {act_dict['slope']} should be a power of two (eg:0.125)"
-        clip_bits = bits + self.proc['frac']-frac
-        shift_bits = nlog_act_dict + self.proc['frac']-frac
+        log_act = -int(np.log2(act_dict['slope']))
+        assert log_act == -np.log2(act_dict['slope']), f"Leaky Relu slope: {act_dict['slope']} should be a power of two (eg:0.125)"
+        shift_bits = log_act + self.proc['frac']-frac

-        x = np.clip(x, -2**(clip_bits-1), 2**(clip_bits-1)-1)
-        x = (x<0)*x + (x>0)*x *(2**nlog_act_dict)
+        x = (x<0)*x + (((x>0)*x) << log_act)
         x = shift_round(x, shift_bits) # = np.around(x/2**shift_bits)
-        x = np.clip(x,-2**(bits-1), 2**(bits-1)-1).astype(int)
+        x = np.clip(x, -2**(bits-log_act-1), 2**(bits-1)-1).astype(int)
     else:
         raise Exception('Only relu is supported yet')

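For reference, below is a minimal standalone sketch of the leaky-ReLU path as it reads after this commit. The shift_round helper is re-implemented here from its inline comment, and the parameter values (bits, frac, proc_frac, the 0.125 slope) and the function name leaky_relu_quant are illustrative assumptions, not names or values taken from the repository.

import numpy as np

def shift_round(x, shift_bits):
    # Assumed stand-in for the repo's shift_round: round-to-nearest of x / 2**shift_bits.
    return np.around(x / 2**shift_bits).astype(int)

def leaky_relu_quant(x, slope, bits, frac, proc_frac):
    # A power-of-two slope lets the positive branch be scaled with a left shift.
    log_act = -int(np.log2(slope))
    assert log_act == -np.log2(slope), f"Leaky Relu slope: {slope} should be a power of two (eg:0.125)"
    shift_bits = log_act + proc_frac - frac

    x = (x < 0)*x + (((x > 0)*x) << log_act)  # scale positives by 2**log_act
    x = shift_round(x, shift_bits)            # = np.around(x / 2**shift_bits)
    # Net effect: positives shifted right by (proc_frac - frac), negatives by log_act more,
    # followed by the asymmetric clip used in the new code.
    return np.clip(x, -2**(bits - log_act - 1), 2**(bits - 1) - 1).astype(int)

# Illustrative call with hypothetical word lengths: 8-bit output, slope 1/8.
x = np.array([-64, -8, 0, 8, 64], dtype=np.int64)
print(leaky_relu_quant(x, slope=0.125, bits=8, frac=3, proc_frac=6))

Reading the diff, the change appears to replace a multiply by 2**log_act and a separate pre-clip with a single left shift and a tighter final clip; that is an inference from the code, not a statement from the author.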