Commit

Merge branch 'main' into linear-cpu-perf-optimization
timmoon10 committed Sep 19, 2024
2 parents 2c464c5 + c0caadb commit 5905b01
Showing 1 changed file with 3 additions and 1 deletion.
transformer_engine/pytorch/attention.py (4 changes: 3 additions & 1 deletion)
@@ -4051,6 +4051,7 @@ def __init__(
         rotary_percent: float = 1.0,
         seq_len_interpolation_factor: Optional[int] = None,
         pretrained_max_position_embeddings: Optional[int] = None,
+        rotary_base: float = 10000.0,
     ):
         """
         Parameters
@@ -4069,8 +4070,9 @@ def __init__(
         if rotary_percent < 1.0:
             dim = int(dim * rotary_percent)
         self.seq_len_interpolation_factor = seq_len_interpolation_factor
+        self.rotary_base = rotary_base
         inv_freq = 1.0 / (
-            10000
+            self.rotary_base
             ** (
                 torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                 / dim

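For context, this change makes the RoPE frequency base configurable instead of hard-coding 10000. Below is a minimal sketch of the parameterized inverse-frequency computation, assuming only PyTorch; rope_inv_freq is an illustrative standalone helper, not a transformer_engine API, and it mirrors the diff's formula on CPU rather than on the current CUDA device.

# Illustrative sketch (not transformer_engine code): the inverse-frequency
# computation that this commit parameterizes, mirrored on CPU.
import torch

def rope_inv_freq(dim: int, rotary_base: float = 10000.0) -> torch.Tensor:
    """Return RoPE inverse frequencies 1 / rotary_base**(2i/dim) for i = 0 .. dim//2 - 1."""
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    return 1.0 / (rotary_base ** exponents)

# Before this commit the base was fixed at 10000.0; now callers can override it.
default_freqs = rope_inv_freq(dim=64)                       # matches the previous hard-coded behavior
larger_base = rope_inv_freq(dim=64, rotary_base=500000.0)   # e.g. a larger base, a common choice for longer contexts

print(default_freqs[:4])
print(larger_base[:4])

After this commit, the constructor touched by the diff (presumably RotaryPositionEmbedding) would accept the same value via its new rotary_base keyword argument.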