Fixed rope for graph
tgaddair committed Jan 2, 2024
1 parent ee7f574 commit 6461974
Showing 7 changed files with 11 additions and 8 deletions.
@@ -218,6 +218,7 @@ def __init__(
             dim=self.head_size,
             base=config.rope_theta,
             device=weights.device,
+            dtype=weights.dtype,
         )

         self.softmax_scale = self.head_size**-0.5

@@ -227,6 +227,7 @@ def __init__(
             dim=self.head_size,
             base=config.rope_theta,
             device=weights.device,
+            dtype=weights.dtype,
         )

         self.softmax_scale = self.head_size**-0.5

@@ -314,6 +314,7 @@ def __init__(
             dim=self.head_size,
             base=config.rope_theta,
             device=weights.device,
+            dtype=weights.dtype,
         )

         self.softmax_scale = self.head_size ** -0.5

@@ -82,6 +82,7 @@ def __init__(
             dim=config.rotary_dim,
             base=rope_theta,
             device=weights.device,
+            dtype=weights.dtype,
         )

         self.softmax_scale = self.head_size**-0.5

@@ -185,6 +185,7 @@ def __init__(
             dim=self.head_size,
             base=config.rope_theta,
             device=weights.device,
+            dtype=weights.dtype,
         )

         self.softmax_scale = self.head_size**-0.5

@@ -130,7 +130,7 @@ def __init__(
         self.head_size = self.hidden_size // self.num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=10000.0, device=weights.device
+            dim=self.head_size, base=10000.0, device=weights.device, dtype=weights.dtype
         )
         self.softmax_scale = self.head_size ** (-0.5)

@@ -242,7 +242,7 @@ def __init__(
         self.head_size = hidden_size // num_heads

         self.rotary_emb = PositionRotaryEmbedding.static(
-            self.head_size, base=10000.0, device=weights.device
+            self.head_size, base=10000.0, device=weights.device, dtype=weights.dtype
         )
         self.softmax_scale = self.head_size ** (-0.5)

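The six hunks above all make the same one-line change: each model's attention module now forwards the weight dtype into PositionRotaryEmbedding.static, so the rotary cos/sin cache can be materialized up front in the model's precision instead of lazily during the first forward pass. A sketch of the resulting call pattern; `config`, `weights`, and `head_size` stand in for the values each `__init__` already receives, and the keyword layout mirrors the diff context rather than any single file verbatim.

```python
# Illustrative call-site sketch, not a verbatim copy of any one model file.
from lorax_server.utils.layers import PositionRotaryEmbedding


def build_rotary_emb(config, weights, head_size):
    # `config.rope_theta`, `weights.device`, and `weights.dtype` are the same
    # attributes the diffs above read; `config` is passed because the updated
    # static() signature in layers.py is (config, dim, base, device, dtype).
    return PositionRotaryEmbedding.static(
        config=config,
        dim=head_size,
        base=config.rope_theta,
        device=weights.device,
        dtype=weights.dtype,  # new argument introduced by this commit
    )
```
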
10 changes: 4 additions & 6 deletions server/lorax_server/utils/layers.py
@@ -675,7 +675,7 @@ def _get_rope_config(config):
         return getattr(config, "rope_scaling", None)

 class PositionRotaryEmbedding(nn.Module):
-    def __init__(self, inv_freq, scaling_factor):
+    def __init__(self, inv_freq, scaling_factor, max_position_embeddings, device, dtype):
         super().__init__()
         self.inv_freq = inv_freq
         self._seq_len_cached = 0
@@ -685,9 +685,10 @@ def __init__(self, inv_freq, scaling_factor):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
+        self._update_cos_sin_cache(dtype, device, max_position_embeddings)

     @classmethod
-    def static(cls, config, dim, base, device):
+    def static(cls, config, dim, base, device, dtype):
         inv_freq = _create_inv_freq(dim, base, device)
         scaling_factor = None
         rope_scaling = _get_rope_config(config)
@@ -717,7 +718,7 @@ def static(cls, config, dim, base, device):
                 raise NotImplementedError(
                     f"rope scaling type {rope_type} is not implemented or invalid"
                 )
-        return cls(inv_freq, scaling_factor)
+        return cls(inv_freq, scaling_factor, config.max_position_embeddings, device, dtype)

     @classmethod
     def load(cls, config, prefix, weights):
@@ -782,9 +783,6 @@ def get_cos_sin(
         """
         Return cos and sin for the asked position ids
         """
-
-        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
         return cos.unsqueeze(1), sin.unsqueeze(1)
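
The layers.py change is the reason the models must now pass dtype: the cos/sin cache used to be rebuilt lazily inside get_cos_sin, which resizes buffers during the forward pass, whereas it is now computed once in __init__ up to config.max_position_embeddings, and get_cos_sin only indexes into fixed tensors. Keeping those buffers static is what makes the op sequence safe to capture and replay as a CUDA graph, presumably the "graph" in the commit title. A minimal, self-contained sketch of the pattern, with hypothetical names (the repository's PositionRotaryEmbedding also handles rope scaling and a dynamic variant, omitted here):

```python
import torch
from torch import nn


class StaticRotaryCache(nn.Module):
    """Illustrative stand-in for the reworked PositionRotaryEmbedding."""

    def __init__(self, dim: int, base: float, max_positions: int, device, dtype):
        super().__init__()
        # Precompute the full cos/sin tables once, in the model's dtype.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device).float() / dim)
        )
        t = torch.arange(max_positions, device=device, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached = torch.cos(freqs).to(dtype)
        self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # Pure indexing into preallocated tensors: no resizing or reallocation,
        # so shapes and storage stay stable across CUDA graph replays.
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)


# Tiny usage check on CPU.
cache = StaticRotaryCache(dim=64, base=10000.0, max_positions=4096,
                          device="cpu", dtype=torch.float32)
cos, sin = cache.get_cos_sin(torch.tensor([0, 1, 2]))
```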
