diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index a25733864..de8b88808 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -4,6 +4,17 @@ from torch import nn +def attention_mask_handler(attention_mask, query_seq_length, key_value_seq_length): + if len(attention_mask.shape) == 4: + attention_mask = attention_mask[0, 0, :, :] + if len(attention_mask.shape) == 3: + attention_mask = attention_mask[0, :, :] + if attention_mask.shape[0] == 1: + assert query_seq_length == key_value_seq_length + attention_mask = attention_mask.repeat(query_seq_length, 1) + return attention_mask + + class MultiheadAttentionWrapper(nn.Module): def __init__( @@ -33,6 +44,40 @@ def __init__( device, dtype) + +class QuantizableOPTAttention(MultiheadAttentionWrapper): + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if key_value_states is None: + key_value_states = hidden_states + if layer_head_mask is not None: + raise RuntimeError("layer_head_mask is not supported.") + if self.batch_first: + query_seq_length = hidden_states.shape[1] + key_value_seq_length = key_value_states.shape[1] + else: + query_seq_length = hidden_states.shape[0] + key_value_seq_length = key_value_states.shape[0] + attention_mask = attention_mask_handler( + attention_mask, query_seq_length, key_value_seq_length) + attn_output, attn_output_weights = self.mha( + hidden_states, + key_value_states, + key_value_states, + attn_mask=attention_mask, + need_weights=output_attentions, + average_attn_weights=False) + past_key_value = None + return attn_output, attn_output_weights, past_key_value + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -99,29 +144,104 @@ def set_weight(value): state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -class QuantizableOPTAttention(MultiheadAttentionWrapper): +class QuantizableBertAttention(MultiheadAttentionWrapper): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if key_value_states is None: - key_value_states = hidden_states - if layer_head_mask is not None: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + if encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + if head_mask is not None: raise RuntimeError("layer_head_mask is not supported.") - if attention_mask is not None: - attention_mask = attention_mask.squeeze() + + if self.mha.batch_first: + query_seq_length = hidden_states.shape[1] + key_value_seq_length = encoder_hidden_states.shape[1] + else: + query_seq_length = hidden_states.shape[0] + key_value_seq_length = encoder_hidden_states.shape[0] + attention_mask = attention_mask_handler( + attention_mask, query_seq_length, key_value_seq_length) + attn_output, attn_output_weights = self.mha( hidden_states, - key_value_states, - key_value_states, + encoder_hidden_states, + encoder_hidden_states, attn_mask=attention_mask, need_weights=output_attentions, average_attn_weights=False) past_key_value = None return attn_output, attn_output_weights, past_key_value + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + + def set_bias(value): + bias_name = f'{prefix}mha.in_proj_bias' + if bias_name in state_dict: + state_dict[bias_name] += value + else: + state_dict[bias_name] = value + + def set_weight(value): + weight_name = f'{prefix}mha.in_proj_weight' + if weight_name in state_dict: + state_dict[weight_name] += value + else: + state_dict[weight_name] = value + + embed_dim = self.mha.embed_dim + for name, value in list(state_dict.items()): + if prefix + 'query.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[:embed_dim] = value + set_weight(weight) + del state_dict[name] + elif prefix + 'key.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[embed_dim:2 * embed_dim] = value + set_weight(weight) + del state_dict[name] + elif prefix + 'value.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[2 * embed_dim:3 * embed_dim] = value + set_weight(weight) + del state_dict[name] + if prefix + 'query.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[:embed_dim] = value + set_bias(bias) + del state_dict[name] + elif prefix + 'key.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[embed_dim:2 * embed_dim] = value + set_bias(bias) + del state_dict[name] + elif prefix + 'value.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[2 * embed_dim:3 * embed_dim] = value + set_bias(bias) + del state_dict[name] + state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0]) + state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape) + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)