diff --git a/server/lorax_server/layers/rotary.py b/server/lorax_server/layers/rotary.py index 83473ded8..6c88d87e2 100644 --- a/server/lorax_server/layers/rotary.py +++ b/server/lorax_server/layers/rotary.py @@ -101,7 +101,7 @@ def static(cls, config, dim, base, device): beta_fast=32, beta_slow=1, ) - elif rope_scaling["type"] == "su": + elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor(rope_scaling["short_factor"], dtype=torch.float32, device=device) short_inv_freq = 1.0 / ( short_factor * base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index eab91c3ba..d358c317b 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -494,7 +494,7 @@ def static(cls, config, dim, base, device, dtype): dtype=dtype, **rope_scaling, ) - elif rope_type == "su": + elif rope_type in ["su", "longrope"]: short_factor = torch.tensor(rope_scaling["short_factor"], dtype=torch.float32, device=device) short_inv_freq = 1.0 / ( short_factor * base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)