Feat (llm): QuantizableBert
Giuseppe5 committed Jun 30, 2023
1 parent b41b614 commit e424b86
Showing 1 changed file with 133 additions and 13 deletions.
146 changes: 133 additions & 13 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -4,6 +4,17 @@
from torch import nn


def attention_mask_handler(attention_mask, query_seq_length, key_value_seq_length):
    # Collapse a HuggingFace-style (batch, heads, query, key) or (batch, query, key)
    # mask to the 2-dim (query, key) mask expected by nn.MultiheadAttention,
    # assuming the mask is shared across batch and heads
    if len(attention_mask.shape) == 4:
        attention_mask = attention_mask[0, 0, :, :]
    if len(attention_mask.shape) == 3:
        attention_mask = attention_mask[0, :, :]
    if attention_mask.shape[0] == 1:
        # A single-row mask is repeated to one row per query position
        assert query_seq_length == key_value_seq_length
        attention_mask = attention_mask.repeat(query_seq_length, 1)
    return attention_mask
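
For reference, a minimal sketch of the shapes this helper accepts; the tensors are illustrative and the import path is inferred from the file path in the diff header, not part of the commit:

import torch

from brevitas_examples.llm.llm_quant.mha_layers import attention_mask_handler

seq_len = 4
# HuggingFace-style additive mask: (batch, heads, query_len, key_len)
mask_4d = torch.zeros(1, 1, seq_len, seq_len)
print(attention_mask_handler(mask_4d, seq_len, seq_len).shape)  # torch.Size([4, 4])

# A single-row mask is repeated to one row per query position
mask_row = torch.zeros(1, seq_len)
print(attention_mask_handler(mask_row, seq_len, seq_len).shape)  # torch.Size([4, 4])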


class MultiheadAttentionWrapper(nn.Module):

    def __init__(
@@ -33,6 +44,40 @@ def __init__(
            device,
            dtype)


class QuantizableOPTAttention(MultiheadAttentionWrapper):

    def forward(
            self,
            hidden_states: torch.Tensor,
            key_value_states: Optional[torch.Tensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            layer_head_mask: Optional[torch.Tensor] = None,
            output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if key_value_states is None:
            key_value_states = hidden_states
        if layer_head_mask is not None:
            raise RuntimeError("layer_head_mask is not supported.")
        if self.mha.batch_first:
            query_seq_length = hidden_states.shape[1]
            key_value_seq_length = key_value_states.shape[1]
        else:
            query_seq_length = hidden_states.shape[0]
            key_value_seq_length = key_value_states.shape[0]
        attention_mask = attention_mask_handler(
            attention_mask, query_seq_length, key_value_seq_length)
        attn_output, attn_output_weights = self.mha(
            hidden_states,
            key_value_states,
            key_value_states,
            attn_mask=attention_mask,
            need_weights=output_attentions,
            average_attn_weights=False)
        past_key_value = None
        return attn_output, attn_output_weights, past_key_value
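
As a quick smoke test, the class can be exercised stand-alone. This sketch assumes the wrapper constructor mirrors nn.MultiheadAttention's arguments, as the collapsed __init__ above suggests; all shapes are illustrative:

import torch

from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

embed_dim, num_heads, batch, seq_len = 64, 4, 2, 8
attn = QuantizableOPTAttention(embed_dim, num_heads, batch_first=True)

hidden_states = torch.randn(batch, seq_len, embed_dim)
mask = torch.zeros(batch, 1, seq_len, seq_len)  # additive mask, zeros = fully visible

out, weights, past = attn(hidden_states, attention_mask=mask, output_attentions=True)
print(out.shape)      # torch.Size([2, 8, 64])
print(weights.shape)  # per-head weights: torch.Size([2, 4, 8, 8])
print(past)           # None: the wrapper does not implement KV caching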

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):
@@ -99,29 +144,104 @@ def set_weight(value):
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class QuantizableBertAttention(MultiheadAttentionWrapper):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if encoder_attention_mask is not None:
            attention_mask = encoder_attention_mask
        if head_mask is not None:
            raise RuntimeError("head_mask is not supported.")
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze()

        if self.mha.batch_first:
            query_seq_length = hidden_states.shape[1]
            key_value_seq_length = encoder_hidden_states.shape[1]
        else:
            query_seq_length = hidden_states.shape[0]
            key_value_seq_length = encoder_hidden_states.shape[0]
        if attention_mask is not None:
            attention_mask = attention_mask_handler(
                attention_mask, query_seq_length, key_value_seq_length)

        attn_output, attn_output_weights = self.mha(
            hidden_states,
            encoder_hidden_states,
            encoder_hidden_states,
            attn_mask=attention_mask,
            need_weights=output_attentions,
            average_attn_weights=False)
        past_key_value = None
        return attn_output, attn_output_weights, past_key_value
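
The forward above routes both self- and cross-attention through the same wrapped module. A hedged sketch of the two call patterns, with constructor arguments again assumed to mirror nn.MultiheadAttention:

import torch

from brevitas_examples.llm.llm_quant.mha_layers import QuantizableBertAttention

embed_dim, num_heads, batch, src_len, tgt_len = 64, 4, 2, 10, 6
attn = QuantizableBertAttention(embed_dim, num_heads, batch_first=True)

decoder_states = torch.randn(batch, tgt_len, embed_dim)
encoder_states = torch.randn(batch, src_len, embed_dim)

# Self-attention: encoder_hidden_states defaults to hidden_states
self_out, _, _ = attn(decoder_states)

# Cross-attention: queries from hidden_states, keys/values from the encoder
cross_out, _, _ = attn(decoder_states, encoder_hidden_states=encoder_states)
print(cross_out.shape)  # torch.Size([2, 6, 64])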

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs):

        def set_bias(value):
            bias_name = f'{prefix}mha.in_proj_bias'
            if bias_name in state_dict:
                state_dict[bias_name] += value
            else:
                state_dict[bias_name] = value

        def set_weight(value):
            weight_name = f'{prefix}mha.in_proj_weight'
            if weight_name in state_dict:
                state_dict[weight_name] += value
            else:
                state_dict[weight_name] = value

        embed_dim = self.mha.embed_dim
        for name, value in list(state_dict.items()):
            # Pack the separate query/key/value projections into the fused
            # in_proj_weight / in_proj_bias tensors of nn.MultiheadAttention
            if prefix + 'query.weight' in name:
                weight = torch.zeros((3 * embed_dim, embed_dim),
                                     device=value.device,
                                     dtype=value.dtype)
                weight[:embed_dim] = value
                set_weight(weight)
                del state_dict[name]
            elif prefix + 'key.weight' in name:
                weight = torch.zeros((3 * embed_dim, embed_dim),
                                     device=value.device,
                                     dtype=value.dtype)
                weight[embed_dim:2 * embed_dim] = value
                set_weight(weight)
                del state_dict[name]
            elif prefix + 'value.weight' in name:
                weight = torch.zeros((3 * embed_dim, embed_dim),
                                     device=value.device,
                                     dtype=value.dtype)
                weight[2 * embed_dim:3 * embed_dim] = value
                set_weight(weight)
                del state_dict[name]
            if prefix + 'query.bias' in name:
                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
                bias[:embed_dim] = value
                set_bias(bias)
                del state_dict[name]
            elif prefix + 'key.bias' in name:
                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
                bias[embed_dim:2 * embed_dim] = value
                set_bias(bias)
                del state_dict[name]
            elif prefix + 'value.bias' in name:
                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
                bias[2 * embed_dim:3 * embed_dim] = value
                set_bias(bias)
                del state_dict[name]
        # BERT keeps the output projection in a separate module, so the wrapped
        # attention's out_proj is mapped to identity weight and zero bias
        state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0])
        state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape)
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
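
A small self-contained check of the fusion performed above, separate from the commit: stacking the q/k/v projection matrices into one in_proj_weight reproduces the individual projections.

import torch

embed_dim = 4
q_w, k_w, v_w = (torch.randn(embed_dim, embed_dim) for _ in range(3))

in_proj_weight = torch.zeros(3 * embed_dim, embed_dim)
in_proj_weight[:embed_dim] = q_w
in_proj_weight[embed_dim:2 * embed_dim] = k_w
in_proj_weight[2 * embed_dim:] = v_w

x = torch.randn(5, embed_dim)
# One matmul produces q, k, v side by side; split recovers each projection
q, k, v = (x @ in_proj_weight.t()).split(embed_dim, dim=-1)
assert torch.allclose(q, x @ q_w.t())
assert torch.allclose(k, x @ k_w.t())
assert torch.allclose(v, x @ v_w.t())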
