3 | 3 | import torch
4 | 4 | from torch import nn
5 | 5 |
| 6 | +from brevitas.nn.equalized_layer import EqualizedModule
| 7 | +from brevitas.utils.torch_utils import KwargsForwardHook
| 8 | +
6 | 9 |
7 | 10 | def attention_mask_handler(
8 | 11 |         attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
@@ -53,6 +56,26 @@ def __init__(
53 | 56 |             device,
54 | 57 |             dtype)
55 | 58 |
| 59 | +    @property
| 60 | +    def wrapped_mha(self):
| 61 | +        mha = self.mha
| 62 | +        # Workaround for activation equalization when mha is wrapped:
| 63 | +        # KwargsForwardHook is inserted during act equalization,
| 64 | +        # EqualizedModule is inserted after act equalization
| 65 | +        if isinstance(mha, KwargsForwardHook):
| 66 | +            mha = mha.module
| 67 | +        if isinstance(mha, EqualizedModule):
| 68 | +            mha = mha.layer
| 69 | +        return mha
| 70 | +
| 71 | +    @property
| 72 | +    def num_heads(self):
| 73 | +        return self.wrapped_mha.num_heads
| 74 | +
| 75 | +    @property
| 76 | +    def batch_first(self):
| 77 | +        return self.wrapped_mha.batch_first
| 78 | +
56 | 79 |     def _load_from_state_dict(
57 | 80 |             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
58 | 81 |             error_msgs):
@@ -134,13 +157,13 @@ def forward(
134 | 157 |             key_value_states = hidden_states
135 | 158 |         if layer_head_mask is not None:
136 | 159 |             raise RuntimeError("layer_head_mask is not supported.")
137 | | -        if self.mha.batch_first:
| 160 | +        if self.batch_first:
138 | 161 |             batch_size, query_seq_length = hidden_states.shape[:2]
139 | 162 |             key_value_seq_length = key_value_states.shape[1]
140 | 163 |         else:
141 | 164 |             query_seq_length, batch_size = hidden_states.shape[:2]
142 | 165 |             key_value_seq_length = key_value_states.shape[0]
143 | | -        num_heads = self.mha.num_heads
| 166 | +        num_heads = self.num_heads
144 | 167 |         attention_mask = attention_mask_handler(
145 | 168 |             attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
146 | 169 |         attn_output, attn_output_weights = self.mha(
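
Note on the new `wrapped_mha` property: during activation equalization `self.mha` may temporarily be a `KwargsForwardHook` (which exposes the wrapped module as `.module`), and after equalization it may be an `EqualizedModule` (which exposes it as `.layer`). Routing `num_heads` and `batch_first` through `wrapped_mha` lets `forward` keep working regardless of which wrapper is currently installed. Below is a minimal, self-contained sketch of that unwrapping order; `FakeKwargsForwardHook` and `FakeEqualizedModule` are hypothetical stand-ins that only mirror the `.module`/`.layer` attributes, not the real Brevitas wrappers.

```python
from torch import nn


class FakeKwargsForwardHook(nn.Module):
    """Hypothetical stand-in for the hook wrapper inserted during act equalization."""

    def __init__(self, module):
        super().__init__()
        self.module = module


class FakeEqualizedModule(nn.Module):
    """Hypothetical stand-in for the wrapper inserted after act equalization."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer


def unwrap_mha(mha):
    # Same unwrapping order as the wrapped_mha property:
    # strip the forward hook first, then the equalized wrapper.
    if isinstance(mha, FakeKwargsForwardHook):
        mha = mha.module
    if isinstance(mha, FakeEqualizedModule):
        mha = mha.layer
    return mha


mha = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
during_eq = FakeKwargsForwardHook(mha)  # how self.mha looks during act equalization
after_eq = FakeEqualizedModule(mha)  # how self.mha looks after act equalization

assert unwrap_mha(during_eq) is mha
assert unwrap_mha(after_eq) is mha
print(unwrap_mha(after_eq).num_heads, unwrap_mha(after_eq).batch_first)  # 4 True
```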