diff --git a/setup.py b/setup.py
index 9e3111ab..01edba16 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.30.22',
+  version = '1.30.23',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 5876ae37..adad800a 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1201,8 +1201,8 @@ def __init__(
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers: Tuple[str] | None = None,
-        layers_execute_order: Tuple[int] | None = None,
+        custom_layers: Tuple[str, ...] | None = None,
+        layers_execute_order: Tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
@@ -1484,7 +1484,8 @@ def forward(
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
-        condition = None
+        condition = None,
+        layers_execute_order: Tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1576,7 +1577,11 @@ def forward(
             self.layer_dropouts
         )

-        layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
+        # able to override the layers execution order on forward, for trying to depth extrapolate
+
+        layers_execute_order = default(layers_execute_order, self.layers_execute_order)
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

         # go through the attention and feedforward layers
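
The substantive change is the new `layers_execute_order` keyword on `AttentionLayers.forward`, which falls back via `default(...)` to the order fixed at init, so the execution order can be overridden per call (e.g. looping shared layers to try depth extrapolation). Below is a minimal sketch of how a call site might use it; the `Decoder` configuration and the doubled order are illustrative assumptions, not taken from this PR.

    import torch
    from x_transformers import Decoder

    # a 4-deep decoder; its flattened layer list alternates attention ('a') and
    # feedforward ('f') blocks, so indices 0..7 address 8 blocks in total
    decoder = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        weight_tie_layers = True  # shared weights, so repeating blocks is meaningful
    )

    x = torch.randn(1, 256, 512)

    # default behavior: uses the execution order fixed at init
    out = decoder(x)

    # override at forward time: run the shared blocks twice, a simple way to
    # probe depth extrapolation without rebuilding the module
    out_deeper = decoder(x, layers_execute_order = tuple(range(8)) * 2)

Because the override is resolved inside `forward`, the same module can be evaluated at its trained depth and at an extrapolated depth without touching `self.layers_execute_order`.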