
Commit

fix baichuan rms
zhangpeng committed Feb 1, 2024
1 parent 1206daf commit 23f57a8
Showing 2 changed files with 12 additions and 38 deletions.
35 changes: 5 additions & 30 deletions autosmoothquant/models/baichuan.py
@@ -46,39 +46,18 @@ def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
     k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
     return q_embed.to(q.dtype), k_embed.to(k.dtype)
 
-class Int8BaichuanRMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.register_buffer('weight', torch.ones(hidden_size, dtype=torch.float32, requires_grad=False))
-        self.epsilon = eps
-
-    def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-
-        # convert into half-precision
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-
-        out = self.weight * hidden_states
-        int8_out = out.round().clamp(-128, 127).to(torch.int8)
-        return int8_out
+class Int8BaichuanRMSNorm(RMSNorm):
 
     @staticmethod
     def from_float(module: RMSNorm,
                    output_scale: float):
         int8_norm = Int8BaichuanRMSNorm(module.weight.numel(), module.epsilon)
 
-        int8_norm.weight.to(module.weight.dtype)
-        int8_norm.weight = module.weight / output_scale
+        int8_norm.weight = torch.nn.Parameter(module.weight / output_scale)
 
         return int8_norm
 
-_RMSNorm = {
-    "per-tensor": Int8BaichuanRMSNorm,
-    "per-token": RMSNorm
-}
-
 # attention is the same as opt
 class Int8BaichuanAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -191,7 +170,7 @@ def forward(
                 attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
             )
 
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(value_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2)
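
The added .to(value_states.dtype) guards against a dtype mismatch: the clamp against torch.finfo(attn_weights.dtype).min just above can promote attn_weights to float32 when the model runs in half precision (torch.tensor of a Python float is float32), while value_states stays in float16/bfloat16, and torch.matmul requires both operands to share a dtype. A short sketch of the pattern, illustrative only and not the repository's attention code:

import torch

def weighted_values(scores: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # compute the softmax in float32 for stability, then cast the probabilities
    # back to the value dtype so the matmul sees matching operands
    probs = torch.nn.functional.softmax(scores.float(), dim=-1)
    return torch.matmul(probs.to(values.dtype), values)

Without the cast, float32 probabilities multiplied against half-precision values raise a dtype error rather than silently promoting.
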
@@ -264,12 +243,8 @@ def __init__(self, config: BaichuanConfig, quant_config: dict[str, str], positio
             hidden_act=config.hidden_act,
             quant_config=quant_config
         )
-        input_layernorm_cls = _RMSNorm[quant_config["qkv"]]
-        post_attention_layernorm_cls = _RMSNorm[quant_config["fc1"]]
-        self.input_layernorm = input_layernorm_cls(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = post_attention_layernorm_cls(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
+        self.input_layernorm = Int8BaichuanRMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.post_attention_layernorm = Int8BaichuanRMSNorm(config.hidden_size, config.rms_norm_eps)
 
     @staticmethod
     def from_float(module: BaichuanLayer,
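
With the _RMSNorm lookup gone, the layer builds Int8BaichuanRMSNorm unconditionally. Because the class now inherits RMSNorm's forward (and presumably starts from the same unit weight the removed class initialised), the un-converted module behaves exactly like the plain norm the per-token path used to construct; the static output scale only enters once from_float is called. A hedged usage sketch, where float_norm and the 0.05 scale are illustrative placeholders:

norm = Int8BaichuanRMSNorm(4096, 1e-6)                        # plain RMSNorm behaviour until converted
int8_norm = Int8BaichuanRMSNorm.from_float(float_norm, 0.05)
# int8_norm.weight now holds float_norm.weight / 0.05, pre-scaling the norm's
# output for the per-tensor int8 projection that consumes it
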
15 changes: 7 additions & 8 deletions autosmoothquant/models/mixtral.py
@@ -96,9 +96,9 @@ def __init__(self, config: MixtralConfig, quant_config: dict[str, str]):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
-        self.w1 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["w1"])
-        self.w2 = W8A8BFP32OFP32LinearWithQuantScale(self.ffn_dim, self.hidden_dim, act_quant=quant_config["w2"])
-        self.w3 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["w1"])
+        self.w1 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["fc1"])
+        self.w2 = W8A8BFP32OFP32LinearWithQuantScale(self.ffn_dim, self.hidden_dim, act_quant=quant_config["fc2"])
+        self.w3 = W8A8BFP32OFP32Linear(self.hidden_dim, self.ffn_dim, act_quant=quant_config["fc1"])
 
         self.act_fn = ACT2FN[config.hidden_act]
 
@@ -111,9 +111,9 @@ def from_float(module: MixtralBLockSparseTop2MLP,
                    moe_input_scale: float,
                    down_input_scale: float):
         int8_module = Int8MixtralBLockSparseTop2MLP(config, quant_config)
-        int8_module.w1 = W8A8BFP32OFP32Linear.from_float(module.w1, moe_input_scale, act_quant=quant_config["w1"])
-        int8_module.w2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(module.w2, down_input_scale, act_quant=quant_config["w2"])
-        int8_module.w3 = W8A8BFP32OFP32Linear.from_float(module.w3, moe_input_scale, act_quant=quant_config["w1"])
+        int8_module.w1 = W8A8BFP32OFP32Linear.from_float(module.w1, moe_input_scale, act_quant=quant_config["fc1"])
+        int8_module.w2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(module.w2, down_input_scale, act_quant=quant_config["fc2"])
+        int8_module.w3 = W8A8BFP32OFP32Linear.from_float(module.w3, moe_input_scale, act_quant=quant_config["fc1"])
         return int8_module
 
 
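
Both Mixtral hunks make the same change: the expert MLP now reads the shared "fc1"/"fc2" activation-quantization keys instead of the model-specific "w1"/"w2" keys, matching what the Baichuan code already looks up ("qkv", "fc1"). A hedged sketch of the quant_config shape this implies; only "qkv", "fc1" and "fc2" appear in this diff, the "out" entry is an assumption:

quant_config = {
    "qkv": "per-tensor",   # attention input projections
    "out": "per-token",    # assumed key for the attention output projection
    "fc1": "per-tensor",   # covers w1/w3, the gate/up projections
    "fc2": "per-token",    # covers w2, the down projection
}

Judging by the removed _RMSNorm mapping and the per-tensor checks further down, "per-tensor" entries fold a static scale into the preceding norm, while "per-token" entries leave the norm untouched and quantize activations dynamically.
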
@@ -150,7 +150,6 @@ def from_float(module: MixtralSparseMoeBlock,
                    moe_input_scale: float,
                    down_input_scales: List[float]):
         int8_module = Int8MixtralSparseMoeBlock(config, quant_config)
-        # int8_module.gate = W8A8BFP32OFP32Linear.from_float(module.gate, moe_input_scale, act_quant=quant_config["w1"])
         int8_module.gate = module.gate
         for i, expert in enumerate(module.experts):
             int8_module.experts[i] = Int8MixtralBLockSparseTop2MLP.from_float(
@@ -214,7 +213,7 @@ def from_float(module: MixtralDecoderLayer,
             )
         else:
             int8_module.input_layernorm = module.input_layernorm
-        if quant_config["w1"] == "per-tensor":
+        if quant_config["fc1"] == "per-tensor":
             int8_module.post_attention_layernorm = Int8MixtralRMSNorm.from_float(
                 module.post_attention_layernorm,
                 moe_input_scale
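
The hunk ends just before the closing parenthesis and the else branch. Mirroring the input_layernorm handling three lines above, the per-token case presumably keeps the original float norm; a hedged completion of the branch, where the lines after moe_input_scale are assumed rather than shown in this diff:

        if quant_config["fc1"] == "per-tensor":
            int8_module.post_attention_layernorm = Int8MixtralRMSNorm.from_float(
                module.post_attention_layernorm,
                moe_input_scale
            )
        else:
            # assumed to mirror the input_layernorm branch above
            int8_module.post_attention_layernorm = module.post_attention_layernorm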
