diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py
index cf694d4eb..151d86a84 100644
--- a/src/brevitas_examples/llm/llm_quant/mha_layers.py
+++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -76,6 +76,41 @@ def num_heads(self):
     def batch_first(self):
         return self.wrapped_mha.batch_first
 
+
+class QuantizableOPTAttention(MultiheadAttentionWrapper):
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            key_value_states: Optional[torch.Tensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            layer_head_mask: Optional[torch.Tensor] = None,
+            output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if key_value_states is None:
+            key_value_states = hidden_states
+        if layer_head_mask is not None:
+            raise RuntimeError("layer_head_mask is not supported.")
+        if self.mha.batch_first:
+            batch_size, query_seq_length = hidden_states.shape[:2]
+            key_value_seq_length = key_value_states.shape[1]
+        else:
+            query_seq_length, batch_size = hidden_states.shape[:2]
+            key_value_seq_length = key_value_states.shape[0]
+        num_heads = self.mha.num_heads
+        attention_mask = attention_mask_handler(
+            attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
+        attn_output, attn_output_weights = self.mha(
+            hidden_states,
+            key_value_states,
+            key_value_states,
+            attn_mask=attention_mask,
+            need_weights=output_attentions,
+            average_attn_weights=False)
+        past_key_value = None
+        return attn_output, attn_output_weights, past_key_value
+
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
             error_msgs):
@@ -142,36 +177,153 @@ def set_weight(value):
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
 
-class QuantizableOPTAttention(MultiheadAttentionWrapper):
+class QuantizableBertAttention(MultiheadAttentionWrapper):
+
+    def __init__(
+            self,
+            all_head_size,
+            num_attention_heads,
+            ln_normalized_shape,
+            dropout=0.,
+            bias=True,
+            add_bias_kv=False,
+            add_zero_attn=False,
+            kdim=None,
+            vdim=None,
+            batch_first=False,
+            ln_eps=1e-05,
+            ln_elementwise_affine=True,
+            ln_bias=True,
+            device=None,
+            dtype=None) -> None:
+        super().__init__(
+            embed_dim=all_head_size,
+            num_heads=num_attention_heads,
+            dropout=dropout,
+            bias=bias,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
+            kdim=kdim,
+            vdim=vdim,
+            batch_first=batch_first,
+            device=device,
+            dtype=dtype)
+        self.ln = nn.LayerNorm(
+            normalized_shape=ln_normalized_shape,
+            eps=ln_eps,
+            elementwise_affine=ln_elementwise_affine,
+            bias=ln_bias,
+            device=device,
+            dtype=dtype)
 
     def forward(
             self,
             hidden_states: torch.Tensor,
-            key_value_states: Optional[torch.Tensor] = None,
-            past_key_value: Optional[Tuple[torch.Tensor]] = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            layer_head_mask: Optional[torch.Tensor] = None,
-            output_attentions: bool = False,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            encoder_hidden_states: Optional[torch.FloatTensor] = None,
+            encoder_attention_mask: Optional[torch.FloatTensor] = None,
+            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+            output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if key_value_states is None:
-            key_value_states = hidden_states
-        if layer_head_mask is not None:
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        if encoder_attention_mask is not None:
+            attention_mask = encoder_attention_mask
+        if head_mask is not None:
             raise RuntimeError("layer_head_mask is not supported.")
-        if self.batch_first:
+
+        if self.mha.batch_first:
             batch_size, query_seq_length = hidden_states.shape[:2]
-            key_value_seq_length = key_value_states.shape[1]
+            key_value_seq_length = encoder_hidden_states.shape[1]
         else:
             query_seq_length, batch_size = hidden_states.shape[:2]
-            key_value_seq_length = key_value_states.shape[0]
-        num_heads = self.num_heads
+            key_value_seq_length = encoder_hidden_states.shape[0]
+        num_heads = self.mha.num_heads
         attention_mask = attention_mask_handler(
             attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
         attn_output, attn_output_weights = self.mha(
             hidden_states,
-            key_value_states,
-            key_value_states,
+            encoder_hidden_states,
+            encoder_hidden_states,
             attn_mask=attention_mask,
             need_weights=output_attentions,
             average_attn_weights=False)
+        ln_output = self.ln(attn_output + hidden_states)
         past_key_value = None
-        return attn_output, attn_output_weights, past_key_value
+        return ln_output, attn_output_weights, past_key_value
+
+    def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+            error_msgs):
+
+        def set_bias(value):
+            bias_name = f'{prefix}mha.in_proj_bias'
+            if bias_name in state_dict:
+                state_dict[bias_name] += value
+            else:
+                state_dict[bias_name] = value
+
+        def set_weight(value):
+            weight_name = f'{prefix}mha.in_proj_weight'
+            if weight_name in state_dict:
+                state_dict[weight_name] += value
+            else:
+                state_dict[weight_name] = value
+
+        embed_dim = self.mha.embed_dim
+        for name, value in list(state_dict.items()):
+            if prefix + 'self.query.weight' in name:
+                weight = torch.zeros((3 * embed_dim, embed_dim),
+                                     device=value.device,
+                                     dtype=value.dtype)
+                weight[:embed_dim] = value
+                set_weight(weight)
+                del state_dict[name]
+            elif prefix + 'self.key.weight' in name:
+                weight = torch.zeros((3 * embed_dim, embed_dim),
+                                     device=value.device,
+                                     dtype=value.dtype)
+                weight[embed_dim:2 * embed_dim] = value
+                set_weight(weight)
+                del state_dict[name]
+            elif prefix + 'self.value.weight' in name:
+                weight = torch.zeros((3 * embed_dim, embed_dim),
+                                     device=value.device,
+                                     dtype=value.dtype)
+                weight[2 * embed_dim:3 * embed_dim] = value
+                set_weight(weight)
+                del state_dict[name]
+            if prefix + 'self.query.bias' in name:
+                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
+                bias[:embed_dim] = value
+                set_bias(bias)
+                del state_dict[name]
+            elif prefix + 'self.key.bias' in name:
+                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
+                bias[embed_dim:2 * embed_dim] = value
+                set_bias(bias)
+                del state_dict[name]
+            elif prefix + 'self.value.bias' in name:
+                bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
+                bias[2 * embed_dim:3 * embed_dim] = value
+                set_bias(bias)
+                del state_dict[name]
+            if prefix + 'output.dense.weight' in name:
+                weight_name = f'{prefix}mha.out_proj.weight'
+                state_dict[weight_name] = value
+                del state_dict[name]
+            if prefix + 'output.dense.bias' in name:
+                weight_name = f'{prefix}mha.out_proj.bias'
+                state_dict[weight_name] = value
+                del state_dict[name]
+            if prefix + 'output.LayerNorm.weight' in name:
+                weight_name = f'{prefix}ln.weight'
+                state_dict[weight_name] = value
+                del state_dict[name]
+            if prefix + 'output.LayerNorm.bias' in name:
+                weight_name = f'{prefix}ln.bias'
+                state_dict[weight_name] = value
+                del state_dict[name]
+        return super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
index 2a9505227..b29c83809 100644
--- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
+++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
@@ -1,11 +1,38 @@
 import warnings
 
+from transformers.models.bert.modeling_bert import BertAttention
 from transformers.models.opt.modeling_opt import OPTAttention
 
 from brevitas.graph import ModuleToModuleByClass
+from brevitas_examples.llm.llm_quant.mha_layers import QuantizableBertAttention
 from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention
 
-QUANTIZABLE_MHA_MAP = {OPTAttention: (QuantizableOPTAttention, {'batch_first': True})}
+QUANTIZABLE_MHA_MAP = {
+    OPTAttention: (QuantizableOPTAttention, {
+        'batch_first': True}),
+    BertAttention: (QuantizableBertAttention, {
+        'batch_first': True}),}
+
+
+def _set_bert_mha_attributes(module):
+    module.all_head_size = module._modules['self'].all_head_size
+    module.num_attention_heads = module._modules['self'].num_attention_heads
+    module.ln_normalized_shape = module._modules['output'].LayerNorm.normalized_shape
+    module.ln_eps = module._modules['output'].LayerNorm.eps
+    module.ln_elementwise_affine = module._modules['output'].LayerNorm.elementwise_affine
+    module.ln_bias = False if module._modules['output'].LayerNorm.bias is None else True
+
+
+_SET_ATTRIBUTES_MAP = {
+    BertAttention: _set_bert_mha_attributes,}
+
+
+def set_mha_attributes(model):
+    for name, module in model.named_modules():
+        mod_type = type(module)
+        if mod_type in _SET_ATTRIBUTES_MAP.keys():
+            _SET_ATTRIBUTES_MAP[mod_type](module)
+    return model
 
 
 def replace_mha_with_quantizable_layers(model, dtype):
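For reviewers, a minimal usage sketch of how the two new helpers are expected to compose on a BERT model. This is not part of the patch; the checkpoint name, the dtype, and the exact call order are illustrative assumptions.

# Illustrative sketch only, not part of the patch. set_mha_attributes copies the
# constructor arguments that QuantizableBertAttention needs (all_head_size,
# num_attention_heads, LayerNorm settings) onto each BertAttention module, so that
# the subsequent ModuleToModuleByClass rewrite can instantiate the replacement.
import torch
from transformers import BertModel

from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
from brevitas_examples.llm.llm_quant.prepare_for_quantize import set_mha_attributes

model = BertModel.from_pretrained('bert-base-uncased')  # assumed checkpoint
model = set_mha_attributes(model)
model = replace_mha_with_quantizable_layers(model, dtype=torch.float32)  # assumed dtype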