
Commit

Better handling of attention mask
Giuseppe5 committed Jul 2, 2023
1 parent e424b86 commit 2a29be1
Showing 1 changed file with 27 additions and 13 deletions.
src/brevitas_examples/llm/llm_quant/mha_layers.py (40 changes: 27 additions & 13 deletions)
@@ -4,12 +4,19 @@
 from torch import nn
 
 
-def attention_mask_handler(attention_mask, query_seq_length, key_value_seq_length):
+def attention_mask_handler(
+        attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
     """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
     (implicit batch_size and n_heads)."""
     if len(attention_mask.shape) == 4:
-        attention_mask = attention_mask[0, 0, :, :]
-    if len(attention_mask.shape) == 3:
-        attention_mask = attention_mask[0, :, :]
-    if attention_mask.shape[0] == 1:
+        if attention_mask.shape[0] == 1:
+            attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
+        if attention_mask.shape[1] == 1:
+            attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
+        attention_mask = attention_mask.view(
+            batch_size * num_heads, query_seq_length, key_value_seq_length)
+    elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1:
+        # This could happen in Encoder-like architecture
+        assert query_seq_length == key_value_seq_length
         attention_mask = attention_mask.repeat(query_seq_length, 1)
     return attention_mask
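
For context, a quick behavioral sketch of the updated handler (not part of the commit; shapes and values are illustrative assumptions): a HuggingFace-style 4D mask of shape (1, 1, query_len, key_len) is repeated over batch and heads and flattened to the (batch_size * num_heads, query_len, key_len) layout, while a 2D (1, key_len) mask is repeated along the query dimension.

import torch

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

# Hypothetical shapes, chosen only for illustration.
batch_size, num_heads, q_len, kv_len = 2, 4, 5, 5

# Additive 4D mask broadcast over batch and heads: (1, 1, q_len, kv_len).
mask_4d = torch.zeros(1, 1, q_len, kv_len)
out_4d = attention_mask_handler(mask_4d, batch_size, num_heads, q_len, kv_len)
print(out_4d.shape)  # torch.Size([8, 5, 5]) == (batch_size * num_heads, q_len, kv_len)

# 2D encoder-style mask (1, kv_len), repeated along the query dimension.
mask_2d = torch.zeros(1, kv_len)
out_2d = attention_mask_handler(mask_2d, batch_size, num_heads, q_len, kv_len)
print(out_2d.shape)  # torch.Size([5, 5])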
@@ -60,14 +67,15 @@ def forward(
             key_value_states = hidden_states
         if layer_head_mask is not None:
             raise RuntimeError("layer_head_mask is not supported.")
-        if self.batch_first:
-            query_seq_length = hidden_states.shape[1]
+        if self.mha.batch_first:
+            batch_size, query_seq_length = hidden_states.shape[:2]
             key_value_seq_length = key_value_states.shape[1]
         else:
-            query_seq_length = hidden_states.shape[0]
+            query_seq_length, batch_size = hidden_states.shape[:2]
             key_value_seq_length = key_value_states.shape[0]
+        num_heads = self.mha.num_heads
         attention_mask = attention_mask_handler(
-            attention_mask, query_seq_length, key_value_seq_length)
+            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             key_value_states,
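
The 3D layout produced by the handler matches what torch.nn.MultiheadAttention accepts as attn_mask, i.e. (batch_size * num_heads, L, S). A minimal sketch of that interaction, with made-up sizes (not code from the repository):

import torch
from torch import nn

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

batch_size, num_heads, embed_dim, seq_len = 2, 4, 8, 5
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

hidden_states = torch.randn(batch_size, seq_len, embed_dim)
mask_4d = torch.zeros(1, 1, seq_len, seq_len)  # broadcastable additive mask

attn_mask = attention_mask_handler(mask_4d, batch_size, num_heads, seq_len, seq_len)
out, weights = mha(hidden_states, hidden_states, hidden_states, attn_mask=attn_mask)
print(out.shape)      # torch.Size([2, 5, 8])
print(weights.shape)  # torch.Size([2, 5, 5]), weights averaged over heads by default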
@@ -167,14 +175,14 @@ def forward(
             raise RuntimeError("layer_head_mask is not supported.")
 
         if self.mha.batch_first:
-            query_seq_length = hidden_states.shape[1]
+            batch_size, query_seq_length = hidden_states.shape[:2]
             key_value_seq_length = encoder_hidden_states.shape[1]
         else:
-            query_seq_length = hidden_states.shape[0]
+            query_seq_length, batch_size = hidden_states.shape[:2]
             key_value_seq_length = encoder_hidden_states.shape[0]
+        num_heads = self.mha.num_heads
         attention_mask = attention_mask_handler(
-            attention_mask, query_seq_length, key_value_seq_length)
-
+            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
             encoder_hidden_states,
@@ -243,5 +251,11 @@ def set_weight(value):
                 del state_dict[name]
                 state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0])
                 state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape)
+            # elif prefix + 'self.output.dense.weight' in name:
+            #     state_dict[prefix + 'mha.out_proj.weight'] = value
+            #     del state_dict[name]
+            # elif prefix + 'self.output.dense.bias' in name:
+            #     state_dict[prefix + 'mha.out_proj.bias'] = value
+            #     del state_dict[name]
         return super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
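
A small, purely illustrative sketch of what the torch.eye / torch.zeros state-dict entries above amount to: with an identity weight and a zero bias, the wrapped module's out_proj acts as a pass-through.

import torch
from torch import nn

embed_dim = 8  # illustrative size
out_proj = nn.Linear(embed_dim, embed_dim)
with torch.no_grad():
    out_proj.weight.copy_(torch.eye(embed_dim))  # mirrors 'mha.out_proj.weight' above
    out_proj.bias.copy_(torch.zeros(embed_dim))  # mirrors 'mha.out_proj.bias' above

x = torch.randn(3, embed_dim)
assert torch.allclose(out_proj(x), x)  # identity weight + zero bias == pass-through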
