diff --git a/autosmoothquant/models/baichuan.py b/autosmoothquant/models/baichuan.py
index c342b7e..b76cd78 100644
--- a/autosmoothquant/models/baichuan.py
+++ b/autosmoothquant/models/baichuan.py
@@ -46,23 +46,7 @@ def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
     k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
     return q_embed.to(q.dtype), k_embed.to(k.dtype)
 
-class Int8BaichuanRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False))
-        self.epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-
-        # convert into half-precision
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        out = self.weight * hidden_states
-        int8_out = out.round().clamp(-128, 127).to(torch.int8)
-        return int8_out
+class Int8BaichuanRMSNorm(RMSNorm):
 
     @staticmethod
     def from_float(module: RMSNorm,
@@ -70,15 +54,10 @@ def from_float(module: RMSNorm,
         int8_norm = Int8BaichuanRMSNorm(module.weight.numel(), module.epsilon)
 
         int8_norm.weight.to(module.weight.dtype)
-        int8_norm.weight = module.weight / output_scale
+        int8_norm.weight = torch.nn.Parameter(module.weight / output_scale)
 
         return int8_norm
 
-_RMSNorm = {
-    "per-tensor": Int8BaichuanRMSNorm,
-    "per-token": RMSNorm
-}
-
 # attention is the same as opt
 class Int8BaichuanAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -191,7 +170,7 @@ def forward(
                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
             )
 
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(value_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
@@ -264,12 +243,8 @@ def __init__(self, config: BaichuanConfig, quant_config: dict[str, str], positio
             hidden_act=config.hidden_act,
             quant_config=quant_config
         )
-        input_layernorm_cls = _RMSNorm[quant_config["qkv"]]
-        post_attention_layernorm_cls = _RMSNorm[quant_config["fc1"]]
-        self.input_layernorm = input_layernorm_cls(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = post_attention_layernorm_cls(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
+        self.input_layernorm = Int8BaichuanRMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.post_attention_layernorm = Int8BaichuanRMSNorm(config.hidden_size, config.rms_norm_eps)
 
     @staticmethod
     def from_float(module: BaichuanLayer,
diff --git a/autosmoothquant/models/mixtral.py b/autosmoothquant/models/mixtral.py
index 9edd52f..7891abd 100644
--- a/autosmoothquant/models/mixtral.py
+++ b/autosmoothquant/models/mixtral.py
@@ -96,9 +96,9 @@ def __init__(self, config: MixtralConfig, quant_config: dict[str, str]):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
-        self.w1 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["w1"])
-        self.w2 = W8A8BFP32OFP32LinearWithQuantScale(self.ffn_dim, self.hidden_dim, act_quant=quant_config["w2"])
-        self.w3 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["w1"])
+        self.w1 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["fc1"])
+        self.w2 = W8A8BFP32OFP32LinearWithQuantScale(self.ffn_dim, self.hidden_dim, act_quant=quant_config["fc2"])
+        self.w3 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["fc1"])
 
         self.act_fn = ACT2FN[config.hidden_act]
 
@@ -111,9 +111,9 @@ def from_float(module: MixtralBLockSparseTop2MLP,
                    moe_input_scale: float,
                    down_input_scale: float):
         int8_module = Int8MixtralBLockSparseTop2MLP(config, quant_config)
-        int8_module.w1 = W8A8BFP32OFP32Linear.from_float(module.w1, moe_input_scale, act_quant=quant_config["w1"])
-        int8_module.w2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(module.w2, down_input_scale, act_quant=quant_config["w2"])
-        int8_module.w3 = W8A8BFP32OFP32Linear.from_float(module.w3, moe_input_scale, act_quant=quant_config["w1"])
+        int8_module.w1 = W8A8BFP32OFP32Linear.from_float(module.w1, moe_input_scale, act_quant=quant_config["fc1"])
+        int8_module.w2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(module.w2, down_input_scale, act_quant=quant_config["fc2"])
+        int8_module.w3 = W8A8BFP32OFP32Linear.from_float(module.w3, moe_input_scale, act_quant=quant_config["fc1"])
         return int8_module
 
@@ -150,7 +150,6 @@ def from_float(module: MixtralSparseMoeBlock,
                    moe_input_scale: float,
                    down_input_scales: List[float]):
         int8_module = Int8MixtralSparseMoeBlock(config, quant_config)
-        # int8_module.gate = W8A8BFP32OFP32Linear.from_float(module.gate, moe_input_scale, act_quant=quant_config["w1"])
         int8_module.gate = module.gate
         for i, expert in enumerate(module.experts):
             int8_module.experts[i] = Int8MixtralBLockSparseTop2MLP.from_float(
@@ -214,7 +213,7 @@ def from_float(module: MixtralDecoderLayer,
             )
         else:
             int8_module.input_layernorm = module.input_layernorm
-        if quant_config["w1"] == "per-tensor":
+        if quant_config["fc1"] == "per-tensor":
             int8_module.post_attention_layernorm = Int8MixtralRMSNorm.from_float(
                 module.post_attention_layernorm,
                 moe_input_scale
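Note on the from_float change in baichuan.py above: once Int8BaichuanRMSNorm inherits from RMSNorm, whose weight is presumably an nn.Parameter rather than the buffer the old class registered, assigning the scaled tensor directly would be rejected by nn.Module.__setattr__, hence the torch.nn.Parameter wrapper. The snippet below is a minimal, self-contained sketch of that behavior; the RMSNorm stand-in is hypothetical and is not the Baichuan implementation.

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        # Hypothetical stand-in: `weight` is an nn.Parameter, as in the parent
        # class Int8BaichuanRMSNorm now subclasses (not the actual Baichuan code).
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.epsilon = eps

    norm = RMSNorm(8)
    output_scale = 0.05

    with torch.no_grad():
        scaled = norm.weight / output_scale  # plain Tensor, no longer a Parameter

    try:
        norm.weight = scaled  # nn.Module.__setattr__ rejects a bare Tensor here
    except TypeError as exc:
        print(exc)  # "cannot assign 'torch.FloatTensor' as parameter 'weight' ..."

    # Wrapping the scaled tensor in nn.Parameter, as the patch does, is accepted.
    norm.weight = nn.Parameter(scaled)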