From 194d32ade5528987e351296264601eee00e72279 Mon Sep 17 00:00:00 2001
From: wunhuang
Date: Fri, 13 Sep 2024 16:55:43 +0000
Subject: [PATCH] [Grok1] Fix the name of the input scale factor for AutoFP8
 runs

Grok1 read the activation scale as `qkv_proj.activation_scaling_factor`,
but FP8 (AutoFP8) checkpoints expose it as `qkv_proj.input_scale`. Switch
to the `input_scale` name, and replace the VLLM_FP8_REDUCE_CONV
environment-variable gate with a check on the layer's quant_config
(`isinstance(quant_config, Fp8Config)`) so the scale is passed to the
fused pre-attention norm whenever the model actually runs in FP8.

---
 vllm/model_executor/models/grok1.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index a2bc0a8c792a0..33173072a5df4 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -21,7 +21,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Grok1 model."""
-import os
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -38,6 +37,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -54,7 +54,6 @@
 attn_output_multiplier = 0.08838834764831845
 output_multiplier_scale = 0.5773502691896257
 max_attn_val = 30.0
-reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", '0') == "1"
 
 
 class Grok1MoE(nn.Module):
@@ -199,6 +198,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
@@ -233,18 +233,16 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        scale = None if not self.use_fp8 else \
+            self.attn.qkv_proj.input_scale
         # Self Attention
         if residual is None:
             residual = hidden_states
-            hidden_states = self.pre_attn_norm(
-                hidden_states, self.attn.qkv_proj.activation_scaling_factor
-            ) if reduce_conversion_kernel else self.pre_attn_norm(
-                hidden_states)
+            hidden_states = self.pre_attn_norm(hidden_states, None, scale)
         else:
             hidden_states, residual = self.pre_attn_norm(
-                hidden_states, self.attn.qkv_proj.activation_scaling_factor,
-                residual) if reduce_conversion_kernel else self.pre_attn_norm(
-                    hidden_states, residual)
+                hidden_states, residual, scale)
+
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
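
Reviewer note (not part of the patch): for anyone reviewing the fused-norm
change, here is a minimal sketch of what `pre_attn_norm(hidden_states,
residual, scale)` is assumed to do after this patch. `pre_attn_norm` and
`input_scale` come from the diff; the class name `RMSNormMaybeFP8`, the
`eps` default, and the e4m3 clamp bound are illustrative assumptions, not
vLLM's actual fused kernel.

    from typing import Optional

    import torch


    class RMSNormMaybeFP8(torch.nn.Module):
        """RMSNorm whose output can be pre-quantized for an FP8 GEMM."""

        def __init__(self, hidden_size: int, eps: float = 1e-5) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(
            self,
            x: torch.Tensor,
            residual: Optional[torch.Tensor] = None,
            scale: Optional[torch.Tensor] = None,  # e.g. qkv_proj.input_scale
        ):
            # Fused residual add: the returned residual is the pre-norm sum.
            if residual is not None:
                x = x + residual
                residual = x
            # Plain RMSNorm, computed in fp32 for numerical stability.
            variance = x.float().pow(2).mean(-1, keepdim=True)
            out = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
            out = out * self.weight
            if scale is not None:
                # Quantize here so the downstream qkv_proj GEMM can consume
                # FP8 activations directly, skipping a separate conversion
                # kernel (the point of the fused path). 448 is the
                # float8_e4m3fn max; assumed here, not taken from vLLM.
                out = (out.float() / scale).clamp(-448.0, 448.0).to(
                    torch.float8_e4m3fn)
            return out if residual is None else (out, residual)

Under this sketch, `pre_attn_norm(hidden_states, None, scale)` yields an FP8
tensor when the model is quantized and a regular tensor otherwise, which is
why the patch computes `scale` once per forward from `self.use_fp8` instead
of consulting an environment variable.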