Feat (llm): QuantizableBert #645

Closed · wants to merge 2 commits
153 changes: 138 additions & 15 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -7,6 +7,26 @@
from brevitas.utils.torch_utils import KwargsForwardHook


def attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
"""Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
(implicit batch_size and n_heads)."""
if len(attention_mask.shape) == 4:
if attention_mask.shape[0] == 1:
attention_mask = attention_mask.repeat(batch_size, 1, 1, 1)
if attention_mask.shape[1] == 1:
attention_mask = attention_mask.repeat(1, num_heads, 1, 1)
if attention_mask.shape[2] == 1:
attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1)
attention_mask = attention_mask.view(
batch_size * num_heads, query_seq_length, key_value_seq_length)
elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1:
        # This can happen in encoder-like architectures
assert query_seq_length == key_value_seq_length
attention_mask = attention_mask.repeat(query_seq_length, 1)
return attention_mask
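
For reference, a minimal sketch (not part of the diff) of what the handler produces for a typical Hugging Face-style `[batch, 1, query, key]` additive mask: the mask is broadcast over heads and flattened into the `(batch * num_heads, query, key)` layout that `nn.MultiheadAttention` accepts as `attn_mask`. Shapes are illustrative.

```python
import torch

batch_size, num_heads, q_len, kv_len = 2, 4, 5, 5
# Additive mask broadcast over heads, as OPT/BERT models commonly pass it.
mask_4d = torch.zeros(batch_size, 1, q_len, kv_len)
mask_3d = attention_mask_handler(mask_4d, batch_size, num_heads, q_len, kv_len)
assert mask_3d.shape == (batch_size * num_heads, q_len, kv_len)  # (8, 5, 5)
```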


def attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length):
"""Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D
@@ -76,6 +96,41 @@ def num_heads(self):
def batch_first(self):
return self.wrapped_mha.batch_first


class QuantizableOPTAttention(MultiheadAttentionWrapper):

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if key_value_states is None:
key_value_states = hidden_states
if layer_head_mask is not None:
raise RuntimeError("layer_head_mask is not supported.")
if self.mha.batch_first:
batch_size, query_seq_length = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[1]
else:
query_seq_length, batch_size = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[0]
num_heads = self.mha.num_heads
attention_mask = attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
attn_output, attn_output_weights = self.mha(
hidden_states,
key_value_states,
key_value_states,
attn_mask=attention_mask,
need_weights=output_attentions,
average_attn_weights=False)
past_key_value = None
return attn_output, attn_output_weights, past_key_value
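
For context (not from the PR), the forward above reduces to a self-attention call on the wrapped module: with `key_value_states` left as `None`, query, key and value all come from `hidden_states`, and the KV cache (`past_key_value`) is dropped. A hedged sketch with a plain `nn.MultiheadAttention` standing in for `self.mha`:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
hidden_states = torch.randn(2, 5, 16)  # [batch, seq, embed]
attn_output, attn_weights = mha(
    hidden_states, hidden_states, hidden_states,  # self-attention: q = k = v
    need_weights=True,
    average_attn_weights=False)
print(attn_output.shape)   # torch.Size([2, 5, 16])
print(attn_weights.shape)  # torch.Size([2, 4, 5, 5]), per-head weights
```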

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
@@ -142,36 +197,104 @@ def set_weight(value):
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class QuantizableOPTAttention(MultiheadAttentionWrapper):
class QuantizableBertAttention(MultiheadAttentionWrapper):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if key_value_states is None:
key_value_states = hidden_states
if layer_head_mask is not None:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if encoder_attention_mask is not None:
attention_mask = encoder_attention_mask
if head_mask is not None:
            raise RuntimeError("head_mask is not supported.")
if self.batch_first:

if self.mha.batch_first:
batch_size, query_seq_length = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[1]
key_value_seq_length = encoder_hidden_states.shape[1]
else:
query_seq_length, batch_size = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[0]
num_heads = self.num_heads
key_value_seq_length = encoder_hidden_states.shape[0]
num_heads = self.mha.num_heads
attention_mask = attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
attn_output, attn_output_weights = self.mha(
hidden_states,
key_value_states,
key_value_states,
encoder_hidden_states,
encoder_hidden_states,
attn_mask=attention_mask,
need_weights=output_attentions,
average_attn_weights=False)
past_key_value = None
return attn_output, attn_output_weights, past_key_value
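
The BERT-style signature differs mainly in that `encoder_hidden_states`, when given, is routed to key and value, i.e. the cross-attention path. A sketch (again with a plain `nn.MultiheadAttention` as a stand-in, not part of the diff):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
hidden_states = torch.randn(2, 5, 16)          # query side
encoder_hidden_states = torch.randn(2, 7, 16)  # key/value side
attn_output, attn_weights = mha(
    hidden_states, encoder_hidden_states, encoder_hidden_states,
    need_weights=True,
    average_attn_weights=False)
print(attn_output.shape)   # torch.Size([2, 5, 16])
print(attn_weights.shape)  # torch.Size([2, 4, 5, 7])
```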

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):

def set_bias(value):
bias_name = f'{prefix}mha.in_proj_bias'
if bias_name in state_dict:
state_dict[bias_name] += value
else:
state_dict[bias_name] = value

def set_weight(value):
weight_name = f'{prefix}mha.in_proj_weight'
if weight_name in state_dict:
state_dict[weight_name] += value
else:
state_dict[weight_name] = value

embed_dim = self.mha.embed_dim
for name, value in list(state_dict.items()):
if prefix + 'query.weight' in name:
weight = torch.zeros((3 * embed_dim, embed_dim),
device=value.device,
dtype=value.dtype)
weight[:embed_dim] = value
set_weight(weight)
del state_dict[name]
elif prefix + 'key.weight' in name:
weight = torch.zeros((3 * embed_dim, embed_dim),
device=value.device,
dtype=value.dtype)
weight[embed_dim:2 * embed_dim] = value
set_weight(weight)
del state_dict[name]
elif prefix + 'value.weight' in name:
weight = torch.zeros((3 * embed_dim, embed_dim),
device=value.device,
dtype=value.dtype)
weight[2 * embed_dim:3 * embed_dim] = value
set_weight(weight)
del state_dict[name]
if prefix + 'query.bias' in name:
bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
bias[:embed_dim] = value
set_bias(bias)
del state_dict[name]
elif prefix + 'key.bias' in name:
bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
bias[embed_dim:2 * embed_dim] = value
set_bias(bias)
del state_dict[name]
elif prefix + 'value.bias' in name:
bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
bias[2 * embed_dim:3 * embed_dim] = value
set_bias(bias)
del state_dict[name]
state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0])
state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape)
return super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
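
The hook above effectively packs BERT's separate `query`/`key`/`value` projections into the fused `in_proj_weight`/`in_proj_bias` layout of `nn.MultiheadAttention`, and sets `out_proj` to an identity (presumably because BERT applies its own output dense layer outside the self-attention module). A standalone illustration of the packing, not part of the diff:

```python
import torch

embed_dim = 8
q_w, k_w, v_w = (torch.randn(embed_dim, embed_dim) for _ in range(3))
q_b, k_b, v_b = (torch.randn(embed_dim) for _ in range(3))

# Equivalent to zero-initialising in_proj_* and adding each block in place.
in_proj_weight = torch.cat([q_w, k_w, v_w], dim=0)  # [3 * embed_dim, embed_dim]
in_proj_bias = torch.cat([q_b, k_b, v_b], dim=0)    # [3 * embed_dim]

# Same slicing as the hook: rows [0:E) = query, [E:2E) = key, [2E:3E) = value.
assert torch.equal(in_proj_weight[:embed_dim], q_w)
assert torch.equal(in_proj_weight[embed_dim:2 * embed_dim], k_w)
assert torch.equal(in_proj_bias[2 * embed_dim:], v_b)
```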