
Commit 7e73791
allow for layers_execute_order to be overridden on forward
lucidrains committed Jun 23, 2024
1 parent 33ea37a commit 7e73791
Showing 2 changed files with 10 additions and 5 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.30.22',
+  version = '1.30.23',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py (13 changes: 9 additions & 4 deletions)
@@ -1201,8 +1201,8 @@ def __init__(
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers: Tuple[str] | None = None,
-        layers_execute_order: Tuple[int] | None = None,
+        custom_layers: Tuple[str, ...] | None = None,
+        layers_execute_order: Tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
@@ -1484,7 +1484,8 @@ def forward(
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
-        condition = None
+        condition = None,
+        layers_execute_order: Tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1576,7 +1577,11 @@ def forward(
             self.layer_dropouts
         )
 
-        layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
+        # able to override the layers execution order on forward, for trying to depth extrapolate
+
+        layers_execute_order = default(layers_execute_order, self.layers_execute_order)
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # go through the attention and feedforward layers
 
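In short, the layer execution order chosen at construction time can now be replaced per call. A minimal usage sketch, assuming x-transformers >= 1.30.23 and the package's Encoder wrapper (a subclass of AttentionLayers, so it inherits the new forward keyword); the looped index order below is purely illustrative:

import torch
from x_transformers import Encoder

enc = Encoder(
    dim = 512,
    depth = 2,   # default layer types give 4 layers: attention, feedforward, attention, feedforward
    heads = 8
)

x = torch.randn(1, 128, 512)

out = enc(x)   # layers run in the order fixed at init

# override the execution order for this call only, e.g. looping the stack twice
# to probe depth extrapolation, as the added comment in the diff suggests
out_looped = enc(x, layers_execute_order = (0, 1, 2, 3, 0, 1, 2, 3))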
