
Commit

address #296
lucidrains committed Nov 26, 2024
1 parent 5791a85 commit b720245
Showing 2 changed files with 11 additions and 3 deletions.
5 changes: 3 additions & 2 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.17',
+  version = '1.42.18',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
@@ -16,10 +16,11 @@
     'transformers'
   ],
   install_requires=[
-    'torch>=2.0',
     'einx>=0.3.0',
     'einops>=0.8.0',
+    'loguru',
     'packaging>=21.0',
+    'torch>=2.0',
   ],
   setup_requires=[
     'pytest-runner',
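For reference, a quick sanity check (not part of the commit) that an installed environment picks up the bumped release and the newly listed loguru dependency. It assumes x-transformers >= 1.42.18 has been installed from PyPI under the name given in setup.py; the 'packaging' dependency used for the comparison is already listed in install_requires.

# Sanity-check sketch (assumption: x-transformers >= 1.42.18 is installed from PyPI)
from importlib.metadata import version
from packaging.version import parse  # 'packaging>=21.0' is already a listed dependency

installed = version('x-transformers')
assert parse(installed) >= parse('1.42.18'), f'expected >= 1.42.18, found {installed}'

import loguru  # newly added to install_requires in this commit
print(f'x-transformers {installed} with loguru available')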
9 changes: 8 additions & 1 deletion x_transformers/x_transformers.py
@@ -20,6 +20,8 @@
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack

+from loguru import logger
+
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

@@ -1580,7 +1582,12 @@ def __init__(

         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

-        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
+        rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
+
+        assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'
+
+        if rotary_emb_dim < 32:
+            logger.warning('when training language model, rotary embedding dimension should be at least 32')

         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
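To make the behavior change concrete, here is a standalone sketch of the changed logic (the helper name resolve_rotary_emb_dim is illustrative and the library's default() helper is inlined): before this commit the rotary dimension was silently clamped up to at least 32 via max(..., 32); now the requested value is kept, values larger than dim_head fail fast, and values below 32 only trigger a warning.

# Standalone sketch of the new rotary_emb_dim handling; not the library code itself
from loguru import logger

def resolve_rotary_emb_dim(dim_head, rotary_emb_dim = None):
    # x-transformers' default(val, d) returns val if it is not None, else d
    rotary_emb_dim = rotary_emb_dim if rotary_emb_dim is not None else dim_head // 2

    # new: reject dimensions larger than the attention head dimension
    assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'

    # new: small dimensions are allowed, but warn instead of silently clamping up to 32
    if rotary_emb_dim < 32:
        logger.warning('when training language model, rotary embedding dimension should be at least 32')

    return rotary_emb_dim

print(resolve_rotary_emb_dim(64))                     # 32, no warning (defaults to dim_head // 2)
print(resolve_rotary_emb_dim(32))                     # 16, warns; previously this was clamped up to 32
try:
    resolve_rotary_emb_dim(64, rotary_emb_dim = 128)  # larger than dim_head now fails fast
except AssertionError as e:
    print('rejected:', e)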
