add back support for num_layers in T5Encoder and other minor changes
Ankur-singh committed Jan 21, 2025
1 parent c3417b3 commit 696380d
Showing 3 changed files with 42 additions and 25 deletions.
33 changes: 16 additions & 17 deletions tests/torchtune/models/t5/test_t5_encoder.py
@@ -48,28 +48,27 @@ def inputs(self):

     def test_forward(self, model, inputs):
         actual = model(inputs)
-        print(actual)
         expected = torch.tensor(
             [
                 [
-                    [0.4958, 0.4845],
-                    [0.4914, 0.4863],
-                    [0.5089, 0.4791],
-                    [0.5946, 0.4383],
-                    [0.4754, 0.4925],
-                    [0.6266, 0.4204],
-                    [0.6327, 0.4167],
-                    [0.6519, 0.4048],
+                    [0.1940, 0.5625],
+                    [0.1893, 0.5681],
+                    [0.2020, 0.5522],
+                    [0.2547, 0.4681],
+                    [0.1769, 0.5822],
+                    [0.2737, 0.4281],
+                    [0.2828, 0.4066],
+                    [0.2841, 0.4033],
                 ],
                 [
-                    [0.4769, 0.4919],
-                    [0.5096, 0.4788],
-                    [0.5347, 0.4679],
-                    [0.6462, 0.4085],
-                    [0.6643, 0.3968],
-                    [0.5970, 0.4371],
-                    [0.5829, 0.4445],
-                    [0.4919, 0.4861],
+                    [0.1796, 0.5792],
+                    [0.2020, 0.5523],
+                    [0.2209, 0.5258],
+                    [0.2802, 0.4128],
+                    [0.2923, 0.3817],
+                    [0.2677, 0.4414],
+                    [0.2458, 0.4847],
+                    [0.1923, 0.5645],
                 ],
             ]
         )

2 changes: 1 addition & 1 deletion torchtune/models/gemma2/_component_builders.py
@@ -114,7 +114,7 @@ def gemma2(
     """
     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
 
-    layers = nn.ModuleList()
+    layers = torch.nn.ModuleList()
     for layer_idx in range(num_layers):
 
         mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)

32 changes: 25 additions & 7 deletions torchtune/models/t5/_encoder.py
@@ -5,12 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from typing import List, Union
+from typing import List, Optional, Union
 
 import torch
 import torch.nn.functional as F
 from torch import nn, Tensor
 from torchtune.modules import MultiHeadAttention
+from torchtune.modules.transformer import _get_clones
 
 
 class T5Encoder(nn.Module):

@@ -21,7 +22,7 @@ class T5Encoder(nn.Module):
     Args:
         token_embedding (nn.Embedding): PyTorch embedding layer to place tokens in an embedding space.
-        layers (Union[List[nn.Module], nn.ModuleList]): A single encoder layer.
+        layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single encoder layer.
         final_norm (nn.Module): Module that applies normalization to the output of the encoder
         num_heads (int): The number of attention heads.
         rel_pos_num_buckets (int): Number of discrete buckets to divide the relative positions into.

@@ -30,24 +31,29 @@
             Distances beyond this are grouped into the last bucket.
             See: :class:`~torchtune.models.t5._encoder.T5EncoderRelativePositionBias`
         max_seq_len (int): The maximum sequence length (context length) of the model.
+        num_layers (Optional[int]): Number of encoder layers, only define when layers is not a list.
+
+    Raises:
+        AssertionError:
+            If ``num_layers`` is set and layer is a list, **or**
+            ``num_layers`` is not set and layer is an ``nn.Module``.
     """
 
     def __init__(
         self,
         *,
         token_embedding: nn.Embedding,
-        layers: Union[List[nn.Module], nn.ModuleList],
+        layers: Union[nn.Module, List[nn.Module], nn.ModuleList],
         final_norm: nn.Module,
         num_heads: int,
         rel_pos_num_buckets: int,
         rel_pos_max_dist: int,
         max_seq_len: int,
-    ):
+        num_layers: Optional[int] = None,
+    ) -> None:
         super().__init__()
         self.token_embedding = token_embedding
-        self.layers = (
-            layers if isinstance(layers, nn.ModuleList) else nn.ModuleList(layers)
-        )
         self.final_norm = final_norm
         self.max_seq_len = max_seq_len
         self.relative_position_bias = T5EncoderRelativePositionBias(

@@ -57,6 +63,18 @@ def __init__(
             max_seq_len=max_seq_len,
         )
 
+        self.layers = None
+        if isinstance(layers, nn.ModuleList):
+            self.layers = layers
+        elif isinstance(layers, list):
+            self.layers = nn.ModuleList(layers)
+        else:
+            if not isinstance(layers, nn.Module):
+                raise AssertionError("num_layers is defined, layers must be a module")
+            if num_layers is None:
+                raise AssertionError("num_layers is not defined, layers must be a list")
+            self.layers = _get_clones(layers, num_layers)
+
     def forward(self, tokens: Tensor) -> Tensor:
         """
         Args:
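For context, here is a minimal usage sketch of the construction paths the updated constructor accepts. It is not part of this commit; the embedding, norm, and DummyLayer modules are toy stand-ins for illustration, not the real torchtune T5 builders.

import torch.nn as nn

from torchtune.models.t5._encoder import T5Encoder


class DummyLayer(nn.Module):
    # Hypothetical stand-in for a real T5 encoder layer (attention + feed-forward).
    def forward(self, x, *args, **kwargs):
        return x


token_embedding = nn.Embedding(100, 2)  # toy sizes: vocab_size=100, embed_dim=2
final_norm = nn.LayerNorm(2)

# Path 1: pass an explicit list (or nn.ModuleList) of layers and leave num_layers unset.
encoder_from_list = T5Encoder(
    token_embedding=token_embedding,
    layers=[DummyLayer() for _ in range(4)],
    final_norm=final_norm,
    num_heads=2,
    rel_pos_num_buckets=4,
    rel_pos_max_dist=4,
    max_seq_len=8,
)

# Path 2 (re-enabled by this commit): pass a single layer plus num_layers;
# the layer is cloned num_layers times via _get_clones.
encoder_from_clones = T5Encoder(
    token_embedding=token_embedding,
    layers=DummyLayer(),
    final_norm=final_norm,
    num_heads=2,
    rel_pos_num_buckets=4,
    rel_pos_max_dist=4,
    max_seq_len=8,
    num_layers=4,
)

Passing a single layer plus num_layers keeps builder code short, while an explicit list still allows per-layer customization; a lone module passed without num_layers falls through to the AssertionError documented in the new Raises section.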
