[Grok1] fix the name of input scale factor for autofp8 run (#183)
Co-authored-by: wunhuang <[email protected]>
kkHuang-amd and wunhuang authored Sep 13, 2024
1 parent 164ce38 commit 72d0cfb
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions vllm/model_executor/models/grok1.py
@@ -21,7 +21,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Grok1 model."""
-import os
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -38,6 +37,7 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -54,7 +54,6 @@
 attn_output_multiplier = 0.08838834764831845
 output_multiplier_scale = 0.5773502691896257
 max_attn_val = 30.0
-reduce_conversion_kernel: bool = os.getenv("VLLM_FP8_REDUCE_CONV", '0') == "1"
 
 
 class Grok1MoE(nn.Module):
@@ -199,6 +198,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.attn = Grok1Attention(hidden_size=self.hidden_size,
@@ -233,18 +233,16 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        scale = None if not self.use_fp8 else \
+            self.attn.qkv_proj.input_scale
         # Self Attention
         if residual is None:
             residual = hidden_states
-            hidden_states = self.pre_attn_norm(
-                hidden_states, self.attn.qkv_proj.activation_scaling_factor
-            ) if reduce_conversion_kernel else self.pre_attn_norm(
-                hidden_states)
+            hidden_states = self.pre_attn_norm(hidden_states, None, scale)
         else:
             hidden_states, residual = self.pre_attn_norm(
-                hidden_states, self.attn.qkv_proj.activation_scaling_factor,
-                residual) if reduce_conversion_kernel else self.pre_attn_norm(
-                hidden_states, residual)
+                hidden_states, residual, scale)
 
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
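A minimal, self-contained sketch of the pattern this commit switches to, for readers outside the diff context. The stub names (QKVProjStub, pick_input_scale) are illustrative and not vLLM classes; only the isinstance(quant_config, Fp8Config) check and the input_scale attribute name mirror the actual change: FP8 is now detected from the layer's quantization config rather than the VLLM_FP8_REDUCE_CONV environment variable, and the activation scale is read under the name AutoFP8 checkpoints use, input_scale, instead of activation_scaling_factor.

from typing import Optional

import torch


class Fp8Config:
    """Stand-in for vllm.model_executor.layers.quantization.fp8.Fp8Config."""


class QKVProjStub(torch.nn.Module):
    """Stand-in for the fused QKV linear layer; an AutoFP8-quantized layer
    carries its activation scale under the attribute name `input_scale`."""

    def __init__(self, input_scale: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.input_scale = input_scale


def pick_input_scale(quant_config: object,
                     qkv_proj: QKVProjStub) -> Optional[torch.Tensor]:
    # Detect FP8 from the quantization config instead of an env var,
    # mirroring `self.use_fp8 = isinstance(quant_config, Fp8Config)`.
    use_fp8 = isinstance(quant_config, Fp8Config)
    # Return the checkpoint's input scale (or None when not FP8) so the
    # caller can pass it to the pre-attention norm, as the diff does with
    # `self.pre_attn_norm(hidden_states, None, scale)`.
    return qkv_proj.input_scale if use_fp8 else None


# Example usage under these assumptions:
#     scale = pick_input_scale(quant_config, QKVProjStub(torch.tensor(0.02)))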
