
Commit

default the depth to length of computed execution order of layers
lucidrains committed May 10, 2024
1 parent 3abed69 commit 58fb026
Showing 3 changed files with 7 additions and 8 deletions.
7 changes: 1 addition & 6 deletions README.md
@@ -693,7 +693,7 @@ model = TransformerWrapper(
)
```

If you wish to do something more sophisticated, say 3 layers, with each layer recurring 4 times before moving on to the next, that is possible as well.
If you wish to do something more sophisticated, say 3 layers, with each layer recurring 4 times before moving on to the next, that is possible as well. Be aware that `layers_execute_order` is 0-indexed.

```python
import torch
@@ -716,11 +716,6 @@ model = TransformerWrapper(
)
)
)

x = torch.randint(0, 256, (1, 1024))

model(x) # (1, 1024, 20000)

```

### Understanding and Improving Transformer From a Multi-Particle Dynamic System Point of View
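For context, here is a hedged sketch of the kind of configuration the README sentence above describes (3 layers, each executed 4 times before moving on, with a 0-indexed `layers_execute_order`). The collapsed lines of the actual README example are not reproduced; apart from `TransformerWrapper`, `Decoder`, `num_tokens`, `max_seq_len`, and `layers_execute_order`, which appear elsewhere in this diff, the keyword arguments below (including the `custom_layers` layout and how its indices map onto the execution order) are assumptions, not the library's documented example.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# assumed layout: 3 attention/feedforward pairs, i.e. 6 modules indexed 0..5
# layers_execute_order is 0-indexed; each pair is run 4 times before the next
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        heads = 8,
        custom_layers = ('a', 'f') * 3,   # attention, feedforward, repeated 3 times
        layers_execute_order = (
            *((0, 1) * 4),                # first pair, 4 passes
            *((2, 3) * 4),                # second pair, 4 passes
            *((4, 5) * 4),                # third pair, 4 passes
        )
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```

Note that `depth` is omitted here on purpose: with this commit, it would default to the length of the execution order (24 in this sketch).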
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.29.0',
version = '1.29.2',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
6 changes: 5 additions & 1 deletion x_transformers/x_transformers.py
@@ -1059,7 +1059,6 @@ def __init__(
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

self.dim = dim
self.depth = depth
self.causal = causal
self.layers = ModuleList([])

@@ -1179,6 +1178,11 @@ def __init__(

self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

# set the depth

depth = default(depth, len(self.layers_execute_order))
self.depth = depth

# stochastic depth

self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
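A minimal sketch of the defaulting step this hunk introduces, assuming the usual `exists`/`default` conveniences (return the value if it is not None, otherwise the fallback); this is an illustration, not the library source.

```python
# minimal sketch of the new defaulting behavior, under the assumptions above

def exists(val):
    return val is not None

def default(val, d):
    # return val when provided, otherwise fall back to d
    return val if exists(val) else d

# hypothetical case: depth left unset, execution order supplied for 3 layer pairs
layers_execute_order = (*((0, 1) * 4), *((2, 3) * 4), *((4, 5) * 4))

depth = default(None, len(layers_execute_order))
print(depth)  # 24 -- depth defaults to the length of the computed execution order
```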
