
Commit 2a10ee2

Fix errors after latest PR
1 parent b6ed36e commit 2a10ee2

File tree: 3 files changed, +28 -5 lines

src/brevitas_examples/llm/llm_quant/mha_layers.py

+25 -2
@@ -3,6 +3,9 @@
 import torch
 from torch import nn
 
+from brevitas.nn.equalized_layer import EqualizedModule
+from brevitas.utils.torch_utils import KwargsForwardHook
+
 
 def attention_mask_handler(
         attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
@@ -53,6 +56,26 @@ def __init__(
             device,
             dtype)
 
+    @property
+    def wrapped_mha(self):
+        mha = self.mha
+        # Workaround for activation equalization for when mha is wrapped
+        # KwargsForwardHook is inserted during act equalization
+        # EqualizedModule is inserted after act equalization
+        if isinstance(mha, KwargsForwardHook):
+            mha = mha.module
+        if isinstance(mha, EqualizedModule):
+            mha = mha.layer
+        return mha
+
+    @property
+    def num_heads(self):
+        return self.wrapped_mha.num_heads
+
+    @property
+    def batch_first(self):
+        return self.wrapped_mha.batch_first
+
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
@@ -134,13 +157,13 @@ def forward(
             key_value_states = hidden_states
         if layer_head_mask is not None:
             raise RuntimeError("layer_head_mask is not supported.")
-        if self.mha.batch_first:
+        if self.batch_first:
             batch_size, query_seq_length = hidden_states.shape[:2]
             key_value_seq_length = key_value_states.shape[1]
         else:
             query_seq_length, batch_size = hidden_states.shape[:2]
             key_value_seq_length = key_value_states.shape[0]
-        num_heads = self.mha.num_heads
+        num_heads = self.num_heads
         attention_mask = attention_mask_handler(
             attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
         attn_output, attn_output_weights = self.mha(
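
Why the new properties are needed: during activation equalization, self.mha is temporarily wrapped (KwargsForwardHook while equalization runs, EqualizedModule afterwards), so forward() can no longer read num_heads or batch_first directly off self.mha. A minimal sketch of the failure mode, using a generic stand-in wrapper rather than the Brevitas classes:

import torch.nn as nn

class Wrapper(nn.Module):
    # Stand-in for KwargsForwardHook / EqualizedModule: it registers the real
    # module as a child and forwards to it, but does not mirror plain Python
    # attributes such as num_heads or batch_first.
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
wrapped = Wrapper(mha)

# wrapped.num_heads  ->  AttributeError once the module is wrapped
print(wrapped.module.num_heads)  # 4; wrapped_mha does this unwrapping via
                                 # .module (KwargsForwardHook) or .layer (EqualizedModule)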

src/brevitas_examples/llm/llm_quant/quantize.py

+1 -1
@@ -268,7 +268,7 @@ def quantize_model(
        'q_scaled_quant': q_scaled_quant,
        'k_transposed_quant': k_transposed_quant,
        'v_quant': v_quant,
-        'out_proj_input_quant': linear_2d_input_quant,
+        'out_proj_input_quant': input_quant,
        'out_proj_weight_quant': weight_quant,
        'out_proj_bias_quant': None,
        'out_proj_output_quant': None,
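
These keys belong to the keyword map the example builds for quantizing multi-head attention; the fix routes the out_proj input through the generic input quantizer rather than the 2D-linear-specific one. A hedged sketch of how such a map could be assembled (only the dict keys and values come from the diff; the helper name and the idea that the map is later unpacked into the quantized MHA layer are assumptions for illustration):

def mha_quant_kwargs(input_quant, weight_quant, q_scaled_quant,
                     k_transposed_quant, v_quant):
    # Hypothetical helper mirroring the corrected mapping above.
    return {
        'q_scaled_quant': q_scaled_quant,
        'k_transposed_quant': k_transposed_quant,
        'v_quant': v_quant,
        # Fixed here: out_proj input now shares the generic input quantizer
        # instead of linear_2d_input_quant.
        'out_proj_input_quant': input_quant,
        'out_proj_weight_quant': weight_quant,
        'out_proj_bias_quant': None,
        'out_proj_output_quant': None,
    }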

src/brevitas_examples/llm/main.py

+2 -2
@@ -78,7 +78,7 @@
     default='stats',
     choices=['stats', 'mse'],
     help=
-    'How scales/zero-point are determined. Default: stats (percentile for static, absmax minmax for dynamic).'
+    'How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).'
 )
 parser.add_argument(
     '--input-scale-precision',
@@ -89,7 +89,7 @@
 parser.add_argument(
     '--input-scale-type',
     type=str,
-    default='float',
+    default='static',
     choices=['static', 'dynamic'],
     help='Whether input scale is a static value or a dynamic value.')
 parser.add_argument(
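
The default='float' change is a bug fix rather than cosmetics: 'float' is not among choices=['static', 'dynamic'], and argparse only enforces choices for values supplied on the command line, so the invalid default was accepted silently by the parser and could only surface as an error downstream. A self-contained reduction of the fixed flag (the parser here is illustrative scaffolding; only the argument definition mirrors the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--input-scale-type',
    type=str,
    default='static',
    choices=['static', 'dynamic'],
    help='Whether input scale is a static value or a dynamic value.')

print(parser.parse_args([]).input_scale_type)                                  # static
print(parser.parse_args(['--input-scale-type', 'dynamic']).input_scale_type)  # dynamic
# With the old default='float', parse_args([]) would still have returned 'float',
# since choices are not checked for defaults.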
