Register the fake implementation of _C::rms_norm_dynamic_per_token_quant
only when torch.ops._C actually exposes that op: the @register_fake
definition is moved under an `if hasattr(torch.ops._C, ...)` guard.
The function signature and body (returns None) are unchanged.

NOTE(review): this patch was recovered from a copy whose newlines and
indentation had been collapsed. Leading whitespace inside the hunk lines is
reconstructed, not original -- verify against the target file, or apply
with `git apply --ignore-whitespace`.

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 8d5dfebc4c03b..1964b934e1986 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -269,16 +269,18 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
-@register_fake("_C::rms_norm_dynamic_per_token_quant")
-def _rms_norm_dynamic_per_token_quant_fake(
-        output: torch.Tensor,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        scales: torch.Tensor,
-        epsilon: float,
-        scale_ub: Optional[torch.Tensor] = None,
-        residual: Optional[torch.Tensor] = None) -> None:
-    return None
+if hasattr(torch.ops._C, "rms_norm_dynamic_per_token_quant"):
+
+    @register_fake("_C::rms_norm_dynamic_per_token_quant")
+    def _rms_norm_dynamic_per_token_quant_fake(
+            output: torch.Tensor,
+            input: torch.Tensor,
+            weight: torch.Tensor,
+            scales: torch.Tensor,
+            epsilon: float,
+            scale_ub: Optional[torch.Tensor] = None,
+            residual: Optional[torch.Tensor] = None) -> None:
+        return None
 
 
 # quantization ops