Skip to content

Commit

Permalink
Fix (ptq): fix for ptq_common
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Oct 31, 2023
1 parent 8717175 commit dedb9b8
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,14 @@ def kwargs_prefix(prefix, weight_kwargs):
'return_quant_tensor': False}
# yapf: enable

quant_act_kwargs = {'act_quant': act_quant, 'return_quant_tensor': True}
# For potentially unsigned activations, we create a separate dict
unsigned_quant_act_kwargs = quant_act_kwargs.copy()
if uint_sym_act_for_unsigned_values:
# In case we support unsigned activation, the output of softmax can be unsigned
quant_mha_kwargs['attn_output_weights_signed'] = False
unsigned_quant_act_kwargs['signed'] = False

# Layerwise is basic quant kwargs + input_quant
layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}

Expand All @@ -374,16 +382,6 @@ def kwargs_prefix(prefix, weight_kwargs):
torch.nn.ConvTranspose1d: (qnn.QuantConvTranspose1d, quant_wbiol_kwargs),
torch.nn.ConvTranspose2d: (qnn.QuantConvTranspose2d, quant_wbiol_kwargs),}

act_quant_and_bit_width = {'act_quant': act_quant, 'bit_width': act_bit_width}
quant_act_kwargs = {**act_quant_and_bit_width, 'return_quant_tensor': True}

# For potentially unsigned activations, we create a separate dict
unsigned_quant_act_kwargs = quant_act_kwargs.copy()
if uint_sym_act_for_unsigned_values:
# In case we support unsigned activation, the output of softmax can be unsigned
quant_mha_kwargs['attn_output_weights_signed'] = False
unsigned_quant_act_kwargs['signed'] = False

quant_act_map = {
torch.nn.ReLU: (qnn.QuantReLU, {
**unsigned_quant_act_kwargs}),
Expand Down

0 comments on commit dedb9b8

Please sign in to comment.