From 58fb0262d026250edc9de86e6a631e7a890de31e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 10 May 2024 09:48:20 -0700
Subject: [PATCH] default the depth to length of computed execution order of
 layers

---
 README.md                        | 7 +------
 setup.py                         | 2 +-
 x_transformers/x_transformers.py | 6 +++++-
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index f8466512..948109eb 100644
--- a/README.md
+++ b/README.md
@@ -693,7 +693,7 @@ model = TransformerWrapper(
 )
 ```
 
-If you wish to do something more sophisticated, say 3 layers, with each layer recurrent 4 times before onto the next, that is possible as well.
+If you wish to do something more sophisticated, say 3 layers, with each layer recurrent 4 times before moving on to the next, that is possible as well. Be aware that `layers_execute_order` is 0-indexed.
 
 ```python
 import torch
@@ -716,11 +716,6 @@ model = TransformerWrapper(
         )
     )
 )
-
-x = torch.randint(0, 256, (1, 1024))
-
-model(x) # (1, 1024, 20000)
-
 ```
 
 ### Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View
diff --git a/setup.py b/setup.py
index 5c2b1fcd..cfad771c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.29.0',
+  version = '1.29.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 56ec37e7..b492e444 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1059,7 +1059,6 @@ def __init__(
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
 
         self.dim = dim
-        self.depth = depth
         self.causal = causal
         self.layers = ModuleList([])
 
@@ -1179,6 +1178,11 @@ def __init__(
 
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
+        # set the depth
+
+        depth = default(depth, len(self.layers_execute_order))
+        self.depth = depth
+
         # stochastic depth
 
         self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
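
---

For context, a minimal sketch of what this change enables, not part of the patch itself: with it applied, `depth` can be omitted whenever `custom_layers` and `layers_execute_order` already pin down the layer count, since `depth` now falls back to `len(layers_execute_order)`. The sketch mirrors the weight-tied recurrence example from the README (3 attention/feedforward pairs, each recurred 4 times before moving on to the next); the `heads` value and the `model.attn_layers.depth` attribute path are assumptions for illustration, not taken from the patch.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        heads = 8,
        # depth omitted on purpose: with this patch it defaults to
        # len(layers_execute_order), i.e. 24 executed layers here
        custom_layers = ('a', 'f') * 3,   # 3 attention ('a') / feedforward ('f') pairs
        layers_execute_order = (          # 0-indexed into custom_layers
            *((0, 1) * 4),                # first pair, recurred 4 times
            *((2, 3) * 4),                # second pair, recurred 4 times
            *((4, 5) * 4),                # third pair, recurred 4 times
        )
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)  # (1, 1024, 20000)

# assumed attribute path; the value is set by the new default in this patch
assert model.attn_layers.depth == 24
```

Deriving `depth` from the execution order keeps `self.depth` meaningful for weight-tied stacks, where the number of executed layers (24 here) exceeds the number of unique layers (6); it is also why the eager `self.depth = depth` assignment moves below the point where `layers_execute_order` has been computed.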