diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index f8ba46b2ea..1e33819e9f 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -4051,6 +4051,7 @@ def __init__(
         rotary_percent: float = 1.0,
         seq_len_interpolation_factor: Optional[int] = None,
         pretrained_max_position_embeddings: Optional[int] = None,
+        rotary_base: float = 10000.0,
     ):
         """
         Parameters
@@ -4069,8 +4070,9 @@
         if rotary_percent < 1.0:
             dim = int(dim * rotary_percent)
         self.seq_len_interpolation_factor = seq_len_interpolation_factor
+        self.rotary_base = rotary_base
         inv_freq = 1.0 / (
-            10000
+            self.rotary_base
             ** (
                 torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                 / dim
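
After this change, callers can tune the RoPE base instead of relying on the hard-coded 10000. Below is a minimal usage sketch, assuming the patched `__init__` belongs to `RotaryPositionEmbedding` (the class name and its `forward(max_seq_len)` signature are not visible in the hunks above); a CUDA device is needed because `inv_freq` is allocated on `torch.cuda.current_device()`.

```python
import torch

# Sketch under stated assumptions: the hunks above patch
# RotaryPositionEmbedding.__init__ in transformer_engine/pytorch/attention.py;
# the class name and forward() signature are inferred, not shown in the diff.
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

dim = 128  # per-head rotary dimension

# Backward compatible: rotary_base defaults to 10000.0, the constant that
# was previously hard-coded into inv_freq.
rope_default = RotaryPositionEmbedding(dim)

# A larger base lowers every rotation frequency, inv_freq = 1 / base**(2i/dim),
# stretching RoPE wavelengths for long-context use (500000.0 is an
# illustrative value, not one mandated by this diff).
rope_long = RotaryPositionEmbedding(dim, rotary_base=500000.0)

# The formula the diff parameterizes, reproduced directly for comparison:
inv_freq = 1.0 / (
    500000.0
    ** (torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") / dim)
)

freqs = rope_long(max_seq_len=4096)  # frequency table for 4096 positions
```

Because `rotary_base` defaults to 10000.0, the value removed from the expression, existing callers and checkpoints produce identical frequencies; only code that passes the new argument changes behavior.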