Skip to content

Commit

Permalink
Revert "help out @cutoken at #159"
Browse files — browse the repository at this point in the history
This reverts commit 8fa7b4c.
  • Loading branch information
lucidrains committed Jun 30, 2023
1 parent baa2d6e commit a957b68
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.16.10',
version = '1.16.9',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
Expand Down
11 changes: 2 additions & 9 deletions x_transformers/x_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,16 +398,12 @@ def __init__(
self,
dim,
use_xpos = False,
scale_base = 512,
interpolation_factor = 1.
scale_base = 512
):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)

assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor

if not use_xpos:
self.register_buffer('scale', None)
return
Expand All @@ -419,8 +415,6 @@ def __init__(

def forward(self, seq_len, device):
t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
t = t / self.interpolation_factor

freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)

Expand Down Expand Up @@ -893,7 +887,6 @@ def __init__(
rotary_emb_dim = None,
rotary_xpos = False,
rotary_xpos_scale_base = 512,
rotary_interpolation_factor = 1.,
custom_layers = None,
sandwich_coef = None,
par_ratio = None,
Expand Down Expand Up @@ -932,7 +925,7 @@ def __init__(
rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)

assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor) if rotary_pos_emb else None
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base) if rotary_pos_emb else None

assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
Expand Down

0 comments on commit a957b68

Please sign in to comment.