diff --git a/setup.py b/setup.py
index 8c5f5349..e86fe4e9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.17',
+  version = '1.42.18',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
@@ -16,10 +16,11 @@
     'transformers'
   ],
   install_requires=[
-    'torch>=2.0',
     'einx>=0.3.0',
     'einops>=0.8.0',
+    'loguru',
     'packaging>=21.0',
+    'torch>=2.0',
   ],
   setup_requires=[
     'pytest-runner',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 66b803a9..966564a1 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -20,6 +20,8 @@
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack
 
+from loguru import logger
+
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
@@ -1580,7 +1582,12 @@ def __init__(
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
-        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
+        rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
+
+        assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'
+
+        if rotary_emb_dim < 32:
+            logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
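
Note on the behavioral change in the last hunk: `rotary_emb_dim` was previously clamped to a minimum of 32 via `max(..., 32)`; it now defaults to `dim_head // 2`, must not exceed `dim_head`, and only emits a warning when it falls below 32. The standalone sketch below mirrors that logic for illustration only; the `resolve_rotary_emb_dim` function and the re-created `default` helper are not part of the diff, and `warnings.warn` stands in for the `loguru` call.

```python
# Illustrative sketch (not part of the diff) of the new rotary_emb_dim handling.
import warnings

def default(val, d):
    # mirrors the x-transformers helper: use the given value if present, else the fallback
    return val if val is not None else d

def resolve_rotary_emb_dim(rotary_emb_dim, dim_head):
    # new behavior: keep the requested dimension instead of clamping it up to 32
    rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)

    # the rotary dimension can never exceed the attention head dimension
    assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'

    # small rotary dimensions are now permitted but flagged (the diff uses loguru's logger.warning)
    if rotary_emb_dim < 32:
        warnings.warn('when training language model, rotary embedding dimension should be at least 32')

    return rotary_emb_dim

print(resolve_rotary_emb_dim(None, 32))   # 16, with a warning (the old code would have clamped this up to 32)
print(resolve_rotary_emb_dim(None, 128))  # 64, no warning
```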