diff --git a/README.md b/README.md
index 2e2b44b..d7e0a6c 100644
--- a/README.md
+++ b/README.md
@@ -18,27 +18,19 @@ pip install -e .
 ## Usage
 ### quantize model
 First add a config file named "quant_config.json" to model path.
-For Baichuan or Llama model, config should be like:
+For currently supported models, the config should look like this:
 
 ```json
 {
-    "qkv_proj": "per-tensor",
-    "o_proj": "per-tensor",
-    "gate_up_proj": "per-tensor",
-    "down_proj": "per-tensor"
-}
-```
-
-As for Opt model, config should be like:
-
-```json
-{
-    "qkv_proj": "per-tensor",
-    "o_proj": "per-tensor",
+    "qkv": "per-tensor",
+    "out": "per-tensor",
     "fc1": "per-tensor",
     "fc2": "per-tensor"
 }
 ```
+
+"qkv" stands for the QKV matmul of attention, and "out" stands for the attention output matmul.
+"fc1" and "fc2" are the FFN layers, which may be referred to as "gate_up" and "down" in Llama-like models. Set each value to "per-tensor" or "per-token" to choose the quantization granularity you want.
 
 Once config is set, generate scales and do model quantization with following command:
 
@@ -72,10 +64,24 @@ Model support list:
 | ---------| ----------------------------|
 | LLaMA-2 | 7B/13B/70B |
 | LLaMA | 7B/13B/30B/65B |
-| Mistral | Soon |
-| OPT | 6.7B/13B/30B |
-| Baichuan-2 | 13B (7B Soon) |
-| Baichuan | 13B (7B Soon) |
+| Mixtral | 8x7B |
+| OPT | 6.7B/13B/30B |
+| Baichuan-2 | 7B/13B |
+| Baichuan | 7B/13B |
+
+## Performance and inference efficiency
+Detailed data coming soon.
+
+Cases:
+
+[codellama-13b with A40](https://github.com/vllm-project/vllm/pull/1508#issuecomment-1824133140). Tested with vLLM.
+
+[llama-13b with A100](https://github.com/vllm-project/vllm/pull/1508#issuecomment-1853826414). Tested with vLLM.
+
+
+
+
+
 ## Reference
 If you find SmoothQuant useful or relevant to your research, please cite their paper:
diff --git a/autosmoothquant/examples/smoothquant_model.py b/autosmoothquant/examples/smoothquant_model.py
index fb53ba7..d903aea 100644
--- a/autosmoothquant/examples/smoothquant_model.py
+++ b/autosmoothquant/examples/smoothquant_model.py
@@ -28,7 +28,7 @@ def parse_args():
                         help='where to save the act scales, activate when generating scales')
     parser.add_argument("--scale-input", type=str, default='scales/llama-13b',
                         help='where to save the act scales, activate when quantizing models')
-    parser.add_argument('--num-samples', type=int, default=4)
+    parser.add_argument('--num-samples', type=int, default=512)
    parser.add_argument('--seq-len', type=int, default=512)
     parser.add_argument("--model-output", type=str, default='quantized_model/llama-13b',
                         help='where to save the quantized models, activate when quantizing models')
@@ -114,4 +114,4 @@ def main():
     int8_model.save_pretrained(output_path)
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/autosmoothquant/models/baichuan.py b/autosmoothquant/models/baichuan.py
index fcd0c5a..b76cd78 100644
--- a/autosmoothquant/models/baichuan.py
+++ b/autosmoothquant/models/baichuan.py
@@ -3,11 +3,13 @@
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.utils import logging
-from typing import Optional, Tuple, List
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from typing import Optional, Tuple, List, Union
 
 from autosmoothquant.layers.nn.linear import W8A8BFP32OFP32LinearWithQuantScale, W8A8BFP32OFP32Linear, W8A8BFP32OFP32QKVLinear
 from autosmoothquant.thirdparty.baichuan.modeling_baichuan import (
     RMSNorm,
+    RotaryEmbedding,
     MLP,
     BaichuanAttention,
     BaichuanLayer,
@@ -28,23 +30,23 @@ "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers." ) -class Int8BaichuanRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False)) - self.epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + - # convert into half-precision - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) +def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids): + cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin) + k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) - out = self.weight * hidden_states - int8_out = out.round().clamp(-128, 127).to(torch.int8) - return int8_out +class Int8BaichuanRMSNorm(RMSNorm): @staticmethod def from_float(module: RMSNorm, @@ -52,23 +54,18 @@ def from_float(module: RMSNorm, int8_norm = Int8BaichuanRMSNorm(module.weight.numel(), module.epsilon) int8_norm.weight.to(module.weight.dtype) - int8_norm.weight = module.weight / output_scale + int8_norm.weight = torch.nn.Parameter(module.weight / output_scale) return int8_norm -_RMSNorm = { - "per-tensor": Int8BaichuanRMSNorm, - "per-token": RMSNorm -} - # attention is the same as opt class Int8BaichuanAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, config: BaichuanConfig, - position_embedding: str, - quant_config: dict[str, str] + quant_config: dict[str, str], + position_embedding: str ): super().__init__() self.config = config @@ -82,32 +79,15 @@ def __init__( raise ValueError( f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" ) - self.qkv_quant_type = quant_config["qkv_proj"] - self.o_quant_type = quant_config["o_proj"] + self.qkv_quant_type = quant_config["qkv"] + self.o_quant_type = quant_config["out"] self.qkv_size = [self.num_heads * self.head_dim] * 3 self.W_pack = W8A8BFP32OFP32QKVLinear(self.qkv_size, self.hidden_size, 3 * self.num_heads * self.head_dim, act_quant=self.qkv_quant_type) self.o_proj = W8A8BFP32OFP32LinearWithQuantScale(self.num_heads * self.head_dim, self.hidden_size, act_quant=self.o_quant_type) - if self.postion_embedding == "ALIBI": - alibi_slopes = _get_alibi_slopes(self.total_num_heads) - alibi_slopes = alibi_slopes[head_start:head_end].tolist() + if self.position_embedding == "ROPE": + self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) - else: - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - ) - self.scaling 
= self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, self.head_dim, - self.scaling) - _shape = BaichuanAttention._shape @staticmethod @@ -115,10 +95,11 @@ def __init__( def from_float(module: BaichuanAttention, config: BaichuanConfig, quant_config: dict[str, str], + position_embedding: str, attn_input_scale: float, attn_output_scale: float, out_input_scale: float): - int8_module = Int8BaichuanAttention(config, quant_config) + int8_module = Int8BaichuanAttention(config, quant_config, position_embedding) # we do not impelement attn for now bacuase we want to use paged attention int8_module.W_pack = W8A8BFP32OFP32QKVLinear.from_float( module.W_pack, attn_input_scale, int8_module.qkv_size, act_quant=int8_module.qkv_quant_type) @@ -158,6 +139,10 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + if self.position_embedding == "ROPE": + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) @@ -166,12 +151,6 @@ def forward( past_key_value = (key_states, value_states) if use_cache else None if xops is not None and self.training: attn_weights = None - # query_states = query_states.transpose(1, 2) - # key_states = key_states.transpose(1, 2) - # value_states = value_states.transpose(1, 2) - # attn_output = xops.memory_efficient_attention( - # query_states, key_states, value_states, attn_bias=attention_mask - # ) with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask) attn_output = attn_output.transpose(1, 2) @@ -191,7 +170,7 @@ def forward( attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) ) - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(value_states.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2) @@ -214,8 +193,8 @@ def __init__( super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gate_up_quant_type = quant_config["gate_up_proj"] - self.down_quant_type = quant_config["down_proj"] + self.gate_up_quant_type = quant_config["fc1"] + self.down_quant_type = quant_config["fc2"] self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size, act_quant=self.gate_up_quant_type) self.down_proj = W8A8BFP32OFP32LinearWithQuantScale(self.intermediate_size, self.hidden_size, act_quant=self.down_quant_type) self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size, act_quant=self.gate_up_quant_type) @@ -250,27 +229,28 @@ def forward(self, x): class Int8BaichuanLayer(nn.Module): - def __init__(self, config: BaichuanConfig, position_embedding: str, quant_config: dict[str, str]): + def __init__(self, config: BaichuanConfig, quant_config: dict[str, str], position_embedding: str): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Int8BaichuanAttention(config=config, position_embedding, quant_config=quant_config) + self.self_attn = Int8BaichuanAttention( + config=config, + position_embedding=position_embedding, + quant_config=quant_config + ) self.mlp = Int8BaichuanMLP( hidden_size=self.hidden_size, 
intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config ) - input_layernorm_cls = _RMSNorm[quant_config["qkv_proj"]] - post_attention_layernorm_cls = _RMSNorm[quant_config["gate_up_proj"]] - self.input_layernorm = input_layernorm_cls(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = post_attention_layernorm_cls( - config.hidden_size, eps=config.rms_norm_eps - ) + self.input_layernorm = Int8BaichuanRMSNorm(config.hidden_size, config.rms_norm_eps) + self.post_attention_layernorm = Int8BaichuanRMSNorm(config.hidden_size, config.rms_norm_eps) @staticmethod def from_float(module: BaichuanLayer, config: BaichuanConfig, quant_config: dict[str, str], + position_embedding: str, attn_input_scale: float, attn_output_scale: float, out_input_scale: float, @@ -279,13 +259,15 @@ def from_float(module: BaichuanLayer, ): int8_module = Int8BaichuanLayer( config, - quant_config + quant_config, + position_embedding ) int8_module.self_attn = Int8BaichuanAttention.from_float( module.self_attn, config, quant_config, + position_embedding, attn_input_scale, attn_output_scale, out_input_scale @@ -298,14 +280,14 @@ def from_float(module: BaichuanLayer, gate_input_scale, down_input_scale ) - if quant_config["qkv_proj"] == "per-tensor": + if quant_config["qkv"] == "per-tensor": int8_module.input_layernorm = Int8BaichuanRMSNorm.from_float( module.input_layernorm, attn_input_scale ) else: int8_module.input_layernorm = module.input_layernorm - if quant_config["gate_up_proj"] == "per-tensor": + if quant_config["fc1"] == "per-tensor": int8_module.post_attention_layernorm = Int8BaichuanRMSNorm.from_float( module.post_attention_layernorm, gate_input_scale @@ -359,7 +341,7 @@ def __init__(self, config: BaichuanConfig, position_embedding: str, quant_config config.vocab_size, config.hidden_size, self.padding_idx ) self.layers = torch.nn.ModuleList( - [Int8BaichuanLayer(config, position_embedding, quant_config) for _ in range(config.num_hidden_layers)] + [Int8BaichuanLayer(config, quant_config, position_embedding) for _ in range(config.num_hidden_layers)] ) self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) @@ -368,51 +350,216 @@ def __init__(self, config: BaichuanConfig, position_embedding: str, quant_config self.max_cache_pos = config.model_max_length self.first_run = True self.alibi_mask = None + self.position_embedding = position_embedding @staticmethod - def from_float(module, decoder_layer_scales, quant_config): - int8_module = Int8BaichuanModel(module.config, quant_config) + def from_float(module, decoder_layer_scales, quant_config, position_embedding): + int8_module = Int8BaichuanModel(module.config, position_embedding, quant_config) int8_module.embed_tokens = module.embed_tokens int8_module.norm = module.norm for i, layer in enumerate(module.layers): int8_module.layers[i] = Int8BaichuanLayer.from_float( - layer, module.config, quant_config, **decoder_layer_scales[i]) + layer, module.config, quant_config, position_embedding, **decoder_layer_scales[i]) return int8_module + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if input_ids is not None and 
inputs_embeds is not None: + raise ValueError( + "You cannot provide both input_ids and inputs_embeds simultaneously" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You need to provide input_ids or inputs_embeds") + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + seq_length_with_past = seq_length + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # baichuan 13b use alibi + if self.position_embedding == "ALIBI": + alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) + + if attention_mask is not None: + if len(attention_mask.shape) == 2: + expanded_mask = attention_mask.to(alibi_mask.dtype) + expanded_mask = torch.tril( + torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0) + ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0) + else: + expanded_mask = attention_mask + bsz = inputs_embeds.size(0) + src_len, tgt_len = alibi_mask.size()[-2:] + expanded_mask = ( + expanded_mask.unsqueeze(1) + .expand(bsz, 1, src_len, tgt_len) + .to(alibi_mask.dtype) + ) + inverted_mask = 1.0 - expanded_mask + inverted_mask = inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min + ) + attention_mask = inverted_mask + alibi_mask.unsqueeze(0) + else: + attention_mask = alibi_mask + else: + # baichuan 7b use rope + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_rope_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + 
last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_rope_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + get_input_embeddings = BaichuanModel.get_input_embeddings set_input_embeddings = BaichuanModel.set_input_embeddings get_alibi_mask = BaichuanModel.get_alibi_mask - forward = BaichuanModel.forward + -class Int8BaichuanBaseForCausalLM(BaichuanPreTrainedModel): +class Int8BaichuanForCausalLM(BaichuanPreTrainedModel): def __init__(self, config, - position_embedding: str, - quant_config, + quant_config: dict[str, str], *model_args, **model_kwargs): super().__init__(config, *model_args, **model_kwargs) + if config.hidden_size == 4096: # 7b + self.position_embedding = "ROPE" + else: # 13b + self.position_embedding = "ALIBI" self.model = Int8BaichuanModel(config, - position_embedding, + self.position_embedding, quant_config) self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False) - if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False): - try: - from .quantizer import quantize_offline, init_model_weight_int4 - except ImportError: - raise ImportError(f"Needs quantize_offline to run quantize.") - quantize_offline(self, 4) # Initialize weights and apply final processing self.post_init() @staticmethod - def from_float(module, decoder_layer_scales, quant_config): + def from_float(module, decoder_layer_scales, quant_config: dict[str, str]): int8_module = Int8BaichuanForCausalLM(module.config, quant_config) print("start trans into int8, this might take a while") + if module.config.hidden_size == 4096: + position_embedding = "ROPE" + else: + position_embedding = "ALIBI" int8_module.model = Int8BaichuanModel.from_float( - module.model, decoder_layer_scales, quant_config) + module.model, decoder_layer_scales, quant_config, position_embedding) int8_module.lm_head = module.lm_head return int8_module @@ -424,13 +571,3 @@ def from_float(module, decoder_layer_scales, quant_config): get_decoder = BaichuanForCausalLM.get_decoder forward = BaichuanForCausalLM.forward prepare_inputs_for_generation = BaichuanForCausalLM.prepare_inputs_for_generation - -class Int8BaichuanForCausalLM(Int8BaiChuanBaseForCausalLM): - - def __init__(self, - config, - linear_method: Optional[LinearMethodBase] = None): - if config.hidden_size == 4096: # 7b - super().__init__(config, "ROPE", linear_method) - else: # 13b - super().__init__(config, "ALIBI", linear_method) diff --git a/autosmoothquant/models/llama.py b/autosmoothquant/models/llama.py index 918ba52..c38556d 100644 --- 
a/autosmoothquant/models/llama.py +++ b/autosmoothquant/models/llama.py @@ -62,8 +62,8 @@ def __init__( f" and `num_heads`: {self.num_heads})." ) - self.qkv_quant_type = quant_config["qkv_proj"] - self.o_quant_type = quant_config["o_proj"] + self.qkv_quant_type = quant_config["qkv"] + self.o_quant_type = quant_config["out"] self.k_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim, act_quant=self.qkv_quant_type) self.v_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim, act_quant=self.qkv_quant_type) self.q_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim, act_quant=self.qkv_quant_type) @@ -99,8 +99,8 @@ def __init__(self, config: LlamaConfig, quant_config: dict[str, str]): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_up_quant_type = quant_config["gate_up_proj"] - self.down_quant_type = quant_config["down_proj"] + self.gate_up_quant_type = quant_config["fc1"] + self.down_quant_type = quant_config["fc2"] self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size, act_quant=self.gate_up_quant_type) self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size, act_quant=self.gate_up_quant_type) self.down_proj = W8A8BFP32OFP32LinearWithQuantScale(self.intermediate_size, self.hidden_size, act_quant=self.down_quant_type) @@ -170,14 +170,14 @@ def from_float(module: LlamaDecoderLayer, gate_input_scale, down_input_scale ) - if quant_config["qkv_proj"] == "per-tensor": + if quant_config["qkv"] == "per-tensor": int8_module.input_layernorm = Int8LlamaRMSNorm.from_float( module.input_layernorm, attn_input_scale ) else: int8_module.input_layernorm = module.input_layernorm - if quant_config["gate_up_proj"] == "per-tensor": + if quant_config["fc1"] == "per-tensor": int8_module.post_attention_layernorm = Int8LlamaRMSNorm.from_float( module.post_attention_layernorm, gate_input_scale diff --git a/autosmoothquant/models/mixtral.py b/autosmoothquant/models/mixtral.py index 9edd52f..7891abd 100644 --- a/autosmoothquant/models/mixtral.py +++ b/autosmoothquant/models/mixtral.py @@ -96,9 +96,9 @@ def __init__(self, config: MixtralConfig, quant_config: dict[str, str]): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size - self.w1 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["w1"]) - self.w2 = W8A8BFP32OFP32LinearWithQuantScale(self.ffn_dim, self.hidden_dim, act_quant=quant_config["w2"]) - self.w3 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["w1"]) + self.w1 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["fc1"]) + self.w2 = W8A8BFP32OFP32LinearWithQuantScale(self.ffn_dim, self.hidden_dim, act_quant=quant_config["fc2"]) + self.w3 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["fc1"]) self.act_fn = ACT2FN[config.hidden_act] @@ -111,9 +111,9 @@ def from_float(module: MixtralBLockSparseTop2MLP, moe_input_scale: float, down_input_scale: float): int8_module = Int8MixtralBLockSparseTop2MLP(config, quant_config) - int8_module.w1 = W8A8BFP32OFP32Linear.from_float(module.w1, moe_input_scale, act_quant=quant_config["w1"]) - int8_module.w2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(module.w2, down_input_scale, act_quant=quant_config["w2"]) - int8_module.w3 = W8A8BFP32OFP32Linear.from_float(module.w3, moe_input_scale, act_quant=quant_config["w1"]) + int8_module.w1 = 
W8A8BFP32OFP32Linear.from_float(module.w1, moe_input_scale, act_quant=quant_config["fc1"]) + int8_module.w2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(module.w2, down_input_scale, act_quant=quant_config["fc2"]) + int8_module.w3 = W8A8BFP32OFP32Linear.from_float(module.w3, moe_input_scale, act_quant=quant_config["fc1"]) return int8_module @@ -150,7 +150,6 @@ def from_float(module: MixtralSparseMoeBlock, moe_input_scale: float, down_input_scales: List[float]): int8_module = Int8MixtralSparseMoeBlock(config, quant_config) - # int8_module.gate = W8A8BFP32OFP32Linear.from_float(module.gate, moe_input_scale, act_quant=quant_config["w1"]) int8_module.gate = module.gate for i, expert in enumerate(module.experts): int8_module.experts[i] = Int8MixtralBLockSparseTop2MLP.from_float( @@ -214,7 +213,7 @@ def from_float(module: MixtralDecoderLayer, ) else: int8_module.input_layernorm = module.input_layernorm - if quant_config["w1"] == "per-tensor": + if quant_config["fc1"] == "per-tensor": int8_module.post_attention_layernorm = Int8MixtralRMSNorm.from_float( module.post_attention_layernorm, moe_input_scale diff --git a/autosmoothquant/models/opt.py b/autosmoothquant/models/opt.py index 87ca261..af1b2ac 100644 --- a/autosmoothquant/models/opt.py +++ b/autosmoothquant/models/opt.py @@ -73,8 +73,8 @@ def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs): self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.qkv_quant_type = quant_config["qkv_proj"] - self.o_quant_type = quant_config["o_proj"] + self.qkv_quant_type = quant_config["qkv"] + self.o_quant_type = quant_config["out"] self.k_proj = W8A8BFP32OFP32Linear(self.embed_dim, self.embed_dim, use_bias=self.enable_bias, act_quant=self.qkv_quant_type) self.v_proj = W8A8BFP32OFP32Linear(self.embed_dim, self.embed_dim, use_bias=self.enable_bias, act_quant=self.qkv_quant_type) self.q_proj = W8A8BFP32OFP32Linear(self.embed_dim, self.embed_dim, use_bias=self.enable_bias, act_quant=self.qkv_quant_type) @@ -153,7 +153,7 @@ def from_float(module: OPTDecoderLayer, module.fc1, fc1_input_scale, act_quant=int8_module.fc1_quant_type) int8_module.fc2 = W8A8BFP32OFP32LinearWithQuantScale.from_float( module.fc2, fc2_input_scale, act_quant=int8_module.fc2_quant_type) - if quant_config["qkv_proj"] == "per-tensor": + if quant_config["qkv"] == "per-tensor": int8_module.self_attn_layer_norm = Int8OPTLayerNorm.from_float( module.self_attn_layer_norm, attn_input_scale) else: diff --git a/autosmoothquant/quantize/calibration.py b/autosmoothquant/quantize/calibration.py index fe346d9..dccedf2 100644 --- a/autosmoothquant/quantize/calibration.py +++ b/autosmoothquant/quantize/calibration.py @@ -19,7 +19,8 @@ def _model_preprocess(model): original_top_k = model.model.layers[0].block_sparse_moe.top_k num_local_experts = getattr(model.config, "num_local_experts") info_dict["original_top_k"] = original_top_k - # To get all expert act scales, we set top_k to the number of total experts here. 
+        # FIXME: To get all expert act scales, we set top_k to the number of total experts,
+        # which might have negative effects on the generated scales.
         for layer in model.model.layers:
             layer.block_sparse_moe.top_k = num_local_experts
     return info_dict
diff --git a/autosmoothquant/thirdparty/baichuan/modeling_baichuan.py b/autosmoothquant/thirdparty/baichuan/modeling_baichuan.py
index 9cc30f6..593da21 100644
--- a/autosmoothquant/thirdparty/baichuan/modeling_baichuan.py
+++ b/autosmoothquant/thirdparty/baichuan/modeling_baichuan.py
@@ -115,25 +115,6 @@ def forward(self, x, seq_len=None):
             self.sin_cached[:, :, :seq_len, ...],
         )
 
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
-    cos = cos_.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin_.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
-    k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
-    return q_embed.to(q.dtype), k_embed.to(k.dtype)
-
-
-
 
 ##baichuan 13B
 def _get_interleave(n):
@@ -642,13 +623,6 @@ def __init__(self, config, *model_args, **model_kwargs):
         super().__init__(config, *model_args, **model_kwargs)
         self.model = BaichuanModel(config)
         self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False)
-        #if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
-        if hasattr(config, "quantization_config") and isinstance(config.quantization_config, dict) and config.quantization_config.get('load_in_4bit', False):
-            try:
-                from .quantizer import quantize_offline, init_model_weight_int4
-            except ImportError:
-                raise ImportError(f"Needs quantize_offline to run quantize.")
-            quantize_offline(self, 4)
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -706,70 +680,6 @@ def from_pretrained(
             )
         else:
             model_kwargs = kwargs
-
-        if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']:
-            try:
-                from .quantizer import init_model_weight_int4
-                from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map
-                from accelerate.utils import CustomDtype
-                from accelerate.utils import get_balanced_memory
-            except ImportError:
-                raise ImportError(f"Needs import model weight init func to run quantize.")
-            # Instantiate model.
- init_contexts = [no_init_weights(_enable=True)] - init_contexts.append(init_empty_weights()) - with ContextManagers(init_contexts): - model = cls(config) - - model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin') - state_dict = torch.load(model_file, map_location="cpu") - model.is_quantized = True - - device_map = kwargs.pop("device_map", None) - torch_dtype = kwargs.pop("torch_dtype", None) - if device_map is not None: - kwargs = {"no_split_module_classes": model._no_split_modules} - target_dtype = CustomDtype.INT4 - max_memory = get_balanced_memory( - model, - dtype=target_dtype, - low_zero=(device_map == "balanced_low_0"), - max_memory=None, - **kwargs, - ) - kwargs["max_memory"] = max_memory - device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs) - model = init_model_weight_int4(config, model, state_dict) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - # If it is a model with generation capabilities, attempt to load the generation config - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - cache_dir=cache_dir, - force_download=force_download, - resume_download=False, - proxies=None, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder="", - _from_auto=False, - _from_pipeline=None, - **kwargs, - ) - except (OSError, TypeError): - logger.info( - "Generation config file not found, using a generation config created from the model config." - ) - pass - - if device_map is not None: - dispatch_model(model, device_map=device_map) - - return model return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, @@ -833,13 +743,6 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - def quantize(self, bits: int): - try: - from .quantizer import quantize_online - except ImportError: - raise ImportError(f"Needs QLinear to run quantize.") - return quantize_online(self, bits) def prepare_inputs_for_generation( self, diff --git a/requirements.txt b/requirements.txt index ff99a29..d1e9fec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers == 4.34.0 +transformers == 4.36.2 datasets accelerate icecream \ No newline at end of file
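
For reference, a minimal sketch of writing the unified `quant_config.json` that the renamed keys in this patch expect ("qkv", "out", "fc1", "fc2", as described in the README changes above); the model directory path is a placeholder to adapt to your setup:

```python
import json

# Unified quantization config consumed by the Int8 attention/MLP modules:
#   "qkv"  -> QKV matmul of attention
#   "out"  -> attention output matmul
#   "fc1"/"fc2" -> FFN layers ("gate_up"/"down" in Llama-like models)
# Each value may be "per-tensor" or "per-token".
quant_config = {
    "qkv": "per-tensor",
    "out": "per-tensor",
    "fc1": "per-tensor",
    "fc2": "per-tensor",
}

# Hypothetical model directory; the file is expected inside the model path.
with open("models/llama-13b/quant_config.json", "w") as f:
    json.dump(quant_config, f, indent=2)
```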