Commit 1aa2ed2

ankitade authored and facebook-github-bot committed
Move transformer impl to flava folder (#439)
Summary: Only used by FLAVA, so moving it to that folder to make space for the other transformer.

Pull Request resolved: #439

Test Plan:
pytest tests/
Sanity check: torchrun --nproc_per_node=1 -m flava.native.train config=flava/native/configs/pretrain_debug.yaml

Reviewed By: ebsmothers
Differential Revision: D47839480
Pulled By: ankitade
fbshipit-source-id: ed635ae192baa16ee1a244b2e8e59bad37a80ad4
1 parent 12bb2bc commit 1aa2ed2
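
For consumers of the library, the visible effect of this commit is an import-path change. A minimal before/after sketch (module paths are taken from the diffs below; nothing else is implied):

# Before this commit, the shared layers module provided the encoder classes:
#   from torchmultimodal.modules.layers.transformer import (
#       TransformerEncoder,
#       TransformerEncoderLayer,
#   )

# After this commit, the FLAVA-specific implementation lives alongside the FLAVA models:
from torchmultimodal.models.flava.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
)

# TransformerOutput stays behind in the shared module:
from torchmultimodal.modules.layers.transformer import TransformerOutput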

File tree: 11 files changed (+249, -244 lines)

examples/flava/native/train.py

Lines changed: 4 additions & 2 deletions
@@ -53,8 +53,10 @@
 from torch.utils.tensorboard import SummaryWriter
 from torchmultimodal.models.flava.image_encoder import ImageTransformer
 from torchmultimodal.models.flava.text_encoder import BERTTextEncoder
-from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
-from torchmultimodal.modules.layers.transformer import TransformerEncoderLayer
+from torchmultimodal.models.flava.transformer import (
+    FLAVATransformerWithoutEmbeddings,
+    TransformerEncoderLayer,
+)
 from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput

examples/flava/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ omegaconf==2.1.2
 hydra-core==1.1.2
 transformers==4.30.0
 pycocotools==2.0.4
+tensorboard
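
Note: the new tensorboard requirement presumably backs the SummaryWriter import that appears in the train.py context above; the dependency is left unversioned in this diff.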

tests/models/flava/test_image_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from tests.test_utils import assert_expected, set_rng_seed
 from torch import nn
 from torchmultimodal.models.flava.image_encoder import ImageEmbeddings, ImageTransformer
-from torchmultimodal.modules.layers.transformer import TransformerEncoder
+from torchmultimodal.models.flava.transformer import TransformerEncoder


 @pytest.fixture(autouse=True)

tests/models/flava/test_text_encoder.py

Lines changed: 4 additions & 2 deletions
@@ -11,10 +11,12 @@
 import torch
 from tests.test_utils import assert_expected, set_rng_seed
 from torch import nn
-from torchmultimodal.models.flava.transformer import init_transformer_weights
+from torchmultimodal.models.flava.transformer import (
+    init_transformer_weights,
+    TransformerEncoder,
+)
 from torchmultimodal.modules.encoders.bert_text_encoder import BERTTextEncoder
 from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
-from torchmultimodal.modules.layers.transformer import TransformerEncoder


 @pytest.fixture(autouse=True)

tests/modules/layers/test_transformer.py renamed to tests/models/flava/test_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 import torch
 from tests.test_utils import assert_expected, set_rng_seed
 from torch import nn
-from torchmultimodal.modules.layers.transformer import (
+from torchmultimodal.models.flava.transformer import (
     TransformerEncoder,
     TransformerEncoderLayer,
 )
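
Since the test file moves with the implementation, a targeted run now uses the new path (consistent with the pytest tests/ step in the test plan):

pytest tests/models/flava/test_transformer.py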

torchmultimodal/models/flava/image_encoder.py

Lines changed: 4 additions & 4 deletions
@@ -11,12 +11,12 @@

 import torch
 from torch import nn, Tensor
-from torchmultimodal.models.flava.transformer import init_transformer_weights
-from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
-from torchmultimodal.modules.layers.transformer import (
+from torchmultimodal.models.flava.transformer import (
+    init_transformer_weights,
     TransformerEncoder,
-    TransformerOutput,
 )
+from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
+from torchmultimodal.modules.layers.transformer import TransformerOutput
 from torchmultimodal.modules.losses.flava import Pooler

torchmultimodal/models/flava/model.py

Lines changed: 5 additions & 5 deletions
@@ -17,13 +17,13 @@
 from torch import nn, Tensor
 from torchmultimodal.models.flava.image_encoder import flava_image_encoder
 from torchmultimodal.models.flava.text_encoder import flava_text_encoder
-from torchmultimodal.models.flava.transformer import FLAVATransformerWithoutEmbeddings
-from torchmultimodal.modules.layers.mlp import MLP
-from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
-from torchmultimodal.modules.layers.transformer import (
+from torchmultimodal.models.flava.transformer import (
+    FLAVATransformerWithoutEmbeddings,
     TransformerEncoder,
-    TransformerOutput,
 )
+from torchmultimodal.modules.layers.mlp import MLP
+from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
+from torchmultimodal.modules.layers.transformer import TransformerOutput
 from torchmultimodal.modules.losses.flava import (
     FLAVAPretrainingLoss,
     FLAVAPretrainingLossOutput,

torchmultimodal/models/flava/text_encoder.py

Lines changed: 4 additions & 2 deletions
@@ -8,11 +8,13 @@
 from typing import Callable

 from torch import nn
-from torchmultimodal.models.flava.transformer import init_transformer_weights
+from torchmultimodal.models.flava.transformer import (
+    init_transformer_weights,
+    TransformerEncoder,
+)
 from torchmultimodal.modules.encoders.bert_text_encoder import BERTTextEncoder
 from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm
 from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
-from torchmultimodal.modules.layers.transformer import TransformerEncoder
 from torchmultimodal.modules.losses.flava import Pooler

torchmultimodal/models/flava/transformer.py

Lines changed: 221 additions & 1 deletion
@@ -5,10 +5,13 @@
 # LICENSE file in the root directory of this source tree.

 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 from torch import nn, Tensor
+from torchmultimodal.modules.layers.attention import MultiHeadAttention, SelfAttention
+from torchmultimodal.modules.layers.mlp import MLP
+from torchmultimodal.modules.layers.normalizations import Fp32LayerNorm

 from torchmultimodal.modules.layers.transformer import TransformerOutput

@@ -75,6 +78,223 @@ def forward(
        )


+class TransformerEncoderLayer(nn.Module):
+    """Transformer encoder layer is made up of multihead self-attention and feedforward blocks,
+    based on the architecture in "Attention Is All You Need" (Vaswani et al. 2017). Similar to
+    ``nn.TransformerEncoderLayer``, but uses a custom ``MultiHeadAttention`` that supports
+    n-dimensional inputs (including sequences, images, video) and head-masking.
+
+    Attributes:
+        d_model (int): size of hidden dimension of input
+        n_head (int): number of attention heads
+        dim_feedforward (int): size of hidden dimension of feedforward network
+        dropout (float): dropout probability for all dropouts. Defaults to 0.
+        activation (Callable): activation function in feedforward network. Defaults to ``nn.ReLU``.
+        layer_norm_eps (float): the eps value in layer norms. Default is 1e-12.
+        norm_first (bool): if True, layer norm is done prior to each of self-attention, cross-attention,
+            and feedforward. Otherwise, layer norm is done after.
+
+    Args:
+        hidden_states (Tensor): input tensor of shape [b, d1, ..., dn, c] to calculate self-attention on.
+        attention_mask (Tensor, optional): mask to be applied to self-attention inputs, ``hidden_states``. See
+            ``MultiHeadAttention`` for shape requirements.
+        head_mask (Tensor, optional): mask to be applied to self-attention inputs after softmax and dropout,
+            before matrix multiplication with values. See ``MultiHeadAttention`` for shape requirements.
+        return_attn_weights (bool, optional): return attention probabilities in addition to attention output.
+            Defaults to False.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        n_head: int,
+        dim_feedforward: int,
+        dropout: float = 0.0,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        layer_norm_eps: float = 1e-12,
+        norm_first: bool = False,
+    ) -> None:
+        super().__init__()
+        # attention block
+        self.attention = MultiHeadAttention(
+            dim_q=d_model,
+            dim_kv=d_model,
+            n_head=n_head,
+            attn_module=SelfAttention(dropout),
+        )
+        self.attention_dropout = nn.Dropout(dropout)
+        # feedforward block
+        self.feedforward = MLP(
+            d_model, d_model, dim_feedforward, dropout=dropout, activation=activation
+        )
+        self.feedforward_dropout = nn.Dropout(dropout)
+        # layernorms
+        self.attention_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
+        self.feedforward_layernorm = Fp32LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm_first = norm_first
+
+    def _attention_block(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        output, attn_weights = self.attention(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            return_attn_weights=True,
+        )
+        output = self.attention_dropout(output)
+        return output, attn_weights
+
+    def _feedforward_block(self, hidden_states: Tensor) -> Tensor:
+        h = self.feedforward(hidden_states)
+        h = self.feedforward_dropout(h)
+        return h
+
+    def _forward_prenorm(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+        x = hidden_states
+        inputs = self.attention_layernorm(x)
+        attn_output, attn_weights = self._attention_block(
+            inputs,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+        )
+        attn_residual = attn_output + x
+        ff_residual = attn_residual + self._feedforward_block(
+            self.feedforward_layernorm(attn_residual)
+        )
+        if return_attn_weights:
+            return ff_residual, attn_weights
+        else:
+            return ff_residual
+
+    def _forward_postnorm(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+        x = hidden_states
+        attn_output, attn_weights = self._attention_block(
+            x,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+        )
+        attn_residual = attn_output + x
+        attn_residual = self.attention_layernorm(attn_residual)
+        ff_residual = attn_residual + self._feedforward_block(attn_residual)
+        outputs = self.feedforward_layernorm(ff_residual)
+        if return_attn_weights:
+            return outputs, attn_weights
+        else:
+            return outputs
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+        if self.norm_first:
+            return self._forward_prenorm(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                return_attn_weights,
+            )
+        else:
+            return self._forward_postnorm(
+                hidden_states,
+                attention_mask,
+                head_mask,
+                return_attn_weights,
+            )
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(
+        self,
+        n_layer: int,
+        d_model: int,
+        n_head: int,
+        dim_feedforward: int,
+        dropout: float = 0.0,
+        activation: Callable[..., nn.Module] = nn.ReLU,
+        layer_norm_eps: float = 1e-12,
+        norm_first: bool = False,
+        final_layer_norm_eps: Optional[float] = None,
+    ):
+        super().__init__()
+        self.layer = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    d_model,
+                    n_head,
+                    dim_feedforward,
+                    dropout,
+                    activation,
+                    layer_norm_eps,
+                    norm_first,
+                )
+                for _ in range(n_layer)
+            ]
+        )
+        self.final_layer_norm = None
+        if final_layer_norm_eps:
+            self.final_layer_norm = Fp32LayerNorm(d_model, eps=final_layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        head_mask: Optional[Tensor] = None,
+        return_attn_weights: bool = False,
+        return_hidden_states: bool = False,
+    ) -> TransformerOutput:
+
+        all_hidden_states = [] if return_hidden_states else None
+        all_self_attentions = [] if return_attn_weights else None
+
+        for layer_module in self.layer:
+            if return_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                return_attn_weights=return_attn_weights,
+            )
+
+            if return_attn_weights:
+                hidden_states = layer_outputs[0]
+                all_self_attentions.append(layer_outputs[1])
+            else:
+                hidden_states = layer_outputs
+
+        if return_hidden_states:
+            all_hidden_states.append(hidden_states)
+
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        return TransformerOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
 def init_transformer_weights(module: nn.Module, initializer_range: float) -> None:
     """Initialize the weights"""
     if isinstance(module, (nn.Linear, nn.Conv2d)):
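
To make the relocated code concrete, here is a minimal usage sketch of TransformerEncoder as defined in the hunk above. The tensor sizes and hyperparameters are illustrative assumptions, not values from this diff:

import torch

from torchmultimodal.models.flava.transformer import TransformerEncoder

# Two pre-norm layers; final_layer_norm_eps enables the trailing Fp32LayerNorm.
encoder = TransformerEncoder(
    n_layer=2,
    d_model=32,
    n_head=4,
    dim_feedforward=64,
    norm_first=True,
    final_layer_norm_eps=1e-12,
)

hidden = torch.randn(2, 8, 32)  # assumed shape: [batch, seq_len, d_model]
out = encoder(hidden, return_hidden_states=True, return_attn_weights=True)

print(out.last_hidden_state.shape)  # torch.Size([2, 8, 32])
print(len(out.hidden_states))  # n_layer + 1 snapshots: one before each layer, plus the final states
print(len(out.attentions))  # one attention-probability tensor per layer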

torchmultimodal/modules/encoders/bert_text_encoder.py

Lines changed: 2 additions & 4 deletions
@@ -8,11 +8,9 @@

 import torch
 from torch import nn, Tensor
+from torchmultimodal.models.flava.transformer import TransformerEncoder
 from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
-from torchmultimodal.modules.layers.transformer import (
-    TransformerEncoder,
-    TransformerOutput,
-)
+from torchmultimodal.modules.layers.transformer import TransformerOutput
 from torchmultimodal.utils.attention import get_extended_attention_mask
