diff --git a/aphrodite/modeling/layers/layernorm.py b/aphrodite/modeling/layers/layernorm.py
index 623e16da7..86a7bbc87 100644
--- a/aphrodite/modeling/layers/layernorm.py
+++ b/aphrodite/modeling/layers/layernorm.py
@@ -137,10 +137,12 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward_native(
-        self,
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
@@ -150,17 +152,33 @@ def forward_native
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * torch.rsqrt(variance + variance_epsilon)
         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
+        x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
         return x if residual is None else (x, residual)
 
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # TODO: Implement an optimized kernel for GemmaRMSNorm.
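+        # Don't nest torch.compile if the caller is already being compiled.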
+        if torch.compiler.is_compiling():
+            return self.forward_native(x, residual)
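+        # Compile forward_static once on first use and cache it on self.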
+        if not getattr(self, "_is_compiled", False):
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
         return self.forward_native(x, residual)
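
For reference, here is a minimal standalone sketch of the Gemma-style RMSNorm math that `forward_static` factors out, with a quick eager-vs-compiled consistency check. The helper name `ref_gemma_rms_norm`, the shapes, and the dtype are illustrative only and not part of this patch; the residual path is omitted for brevity, and a PyTorch build with `torch.compile` support is assumed.

```python
import torch

def ref_gemma_rms_norm(weight: torch.Tensor, eps: float,
                       x: torch.Tensor) -> torch.Tensor:
    # Same math as forward_static: normalize in fp32, scale by (1 + weight)
    # before casting back (the Gemma convention), then restore the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + weight.float())
    return x.to(orig_dtype)

hidden = 64
weight = torch.zeros(hidden)   # GemmaRMSNorm initializes its weight to zeros
x = torch.randn(4, hidden, dtype=torch.float16)

eager = ref_gemma_rms_norm(weight, 1e-6, x)
compiled = torch.compile(ref_gemma_rms_norm)
torch.testing.assert_close(eager, compiled(weight, 1e-6, x))
```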
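
The new `forward_cuda` compiles `forward_static` lazily on the first eager call and skips that step when an outer `torch.compile` trace is already running. Below is a toy module sketching that compile-once-and-cache pattern; the class and its method bodies are illustrative, not the aphrodite implementation, and `torch.compiler.is_compiling()` assumes a recent PyTorch release.

```python
import torch
import torch.nn as nn

class LazyCompiledNorm(nn.Module):
    """Toy illustration of the forward_cuda pattern: compile a static
    function lazily and cache the compiled callable on the instance."""

    @staticmethod
    def forward_static(x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            # An outer torch.compile trace is active: run the uncompiled
            # static function so compilations do not nest.
            return self.forward_static(x)
        if not getattr(self, "_is_compiled", False):
            # First eager call: compile once; the instance attribute now
            # shadows the class-level staticmethod for this module only.
            self.forward_static = torch.compile(self.forward_static)
            self._is_compiled = True
        return self.forward_static(x)

norm = LazyCompiledNorm()
y = norm(torch.randn(2, 8))   # first call triggers compilation
y = norm(torch.randn(2, 8))   # later calls reuse the compiled function
```

Passing `weight` and `variance_epsilon` into a `@staticmethod`, as the patch does, presumably keeps `self` out of the compiled function, so the compiled graph only sees plain tensor and scalar inputs.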