From bfc3da41ae746f1aaa1f0129200eab896af12f3a Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Sun, 15 Dec 2024 17:05:09 -0800 Subject: [PATCH] feat: add torch.compile for GemmaRMSNorm (#898) --- aphrodite/modeling/layers/layernorm.py | 28 ++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/aphrodite/modeling/layers/layernorm.py b/aphrodite/modeling/layers/layernorm.py index 623e16da7..86a7bbc87 100644 --- a/aphrodite/modeling/layers/layernorm.py +++ b/aphrodite/modeling/layers/layernorm.py @@ -137,10 +137,12 @@ def __init__( self.weight = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps - def forward_native( - self, + @staticmethod + def forward_static( + weight: torch.Tensor, + variance_epsilon: float, x: torch.Tensor, - residual: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor], ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype @@ -150,17 +152,31 @@ def forward_native( x = x.float() variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * torch.rsqrt(variance + variance_epsilon) # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) # See https://github.com/huggingface/transformers/pull/29402 - x = x * (1.0 + self.weight.float()) + x = x * (1.0 + weight.float()) x = x.to(orig_dtype) return x if residual is None else (x, residual) + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static(self.weight.data, self.variance_epsilon, x, + residual) + def forward_cuda( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # TODO: Implement an optimized kernel for GemmaRMSNorm. + if torch.compiler.is_compiling(): + return self.forward_native(x, residual) + if not getattr(self, "_is_compiled", False): + self.forward_static = torch.compile( # type: ignore + self.forward_static) + self._is_compiled = True return self.forward_native(x, residual)