Commit ee15f59

Ankita De authored and facebook-github-bot committed
Copy transformer impl to oss folder
Differential Revision: D48132386
fbshipit-source-id: 4838e98a992313040fabde82d2082c48cb802fd5
1 parent 81e281c commit ee15f59

File tree: 2 files changed, +433 −4 lines

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected, init_weights_with_constant, set_rng_seed
from torch import Tensor
from torchmultimodal.modules.layers.transformer import (
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerOutput,
)


@pytest.fixture(autouse=True)
def random():
    set_rng_seed(4)


class TestTransformerEncoderLayer:
    @pytest.fixture
    def get_encoder_layer(self):
        def create_layer(norm_first):
            model = TransformerEncoderLayer(
                d_model=2,
                n_head=1,
                dim_feedforward=2,
                norm_first=norm_first,
            )
            init_weights_with_constant(model)
            model.eval()
            return model

        return create_layer

    @pytest.fixture
    def inputs(self):
        return Tensor([[[1, 2], [4, 2], [1, 1]]])

    @pytest.mark.parametrize(
        "norm_first, expected_output",
        [
            (True, Tensor([[[15.0, 16.0], [18.0, 16.0], [15.0, 15.0]]])),
            (False, Tensor([[[0.0, 2.0], [2.0, 0.0], [1.0, 1.0]]])),
        ],
    )
    def test_forward(self, norm_first, expected_output, inputs, get_encoder_layer):
        model = get_encoder_layer(norm_first)
        actual = model(inputs)
        assert_expected(actual, expected_output, rtol=0, atol=1e-4)

    @pytest.mark.parametrize(
        "norm_first",
        [(True,), (False,)],
    )
    def test_scripting(self, norm_first, inputs, get_encoder_layer):
        model = get_encoder_layer(norm_first)
        scripted_model = torch.jit.script(model)
        assert_expected(scripted_model(inputs), model(inputs), rtol=0, atol=1e-4)


class TestTransformerEncoder:
    @pytest.fixture
    def get_encoder(self):
        def create_encoder(norm_first, final_layer_norm_eps=None):
            model = TransformerEncoder(
                n_layer=2,
                d_model=2,
                n_head=1,
                dim_feedforward=2,
                norm_first=norm_first,
                final_layer_norm_eps=final_layer_norm_eps,
            )
            init_weights_with_constant(model)
            model.eval()
            return model

        return create_encoder

    @pytest.fixture
    def inputs(self):
        return Tensor([[[2, 3], [1, 2]]])

    @pytest.mark.parametrize(
        "norm_first, return_hidden_states, expected_output",
        [
            (
                True,
                False,
                TransformerOutput(
                    last_hidden_state=Tensor([[[30.0, 31.0], [29.0, 30.0]]])
                ),
            ),
            (
                False,
                False,
                TransformerOutput(last_hidden_state=Tensor([[[0.0, 2.0], [0.0, 2.0]]])),
            ),
            (
                True,
                True,
                TransformerOutput(
                    last_hidden_state=Tensor([[[30.0, 31.0], [29.0, 30.0]]]),
                    hidden_states=[
                        Tensor([[[16.0, 17.0], [15.0, 16.0]]]),
                        Tensor([[[30.0, 31.0], [29.0, 30.0]]]),
                    ],
                ),
            ),
            (
                False,
                True,
                TransformerOutput(
                    last_hidden_state=Tensor([[[0.0, 2.0], [0.0, 2.0]]]),
                    hidden_states=[
                        Tensor([[[0.0, 2.0], [0.0, 2.0]]]),
                        Tensor([[[0.0, 2.0], [0.0, 2.0]]]),
                    ],
                ),
            ),
        ],
    )
    def test_forward(
        self, norm_first, return_hidden_states, expected_output, inputs, get_encoder
    ):
        model = get_encoder(norm_first)
        actual = model(inputs, return_hidden_states=return_hidden_states)
        if expected_output.hidden_states is None:
            assert actual.hidden_states is None
        else:
            assert_expected(actual.hidden_states[0], inputs)
            for state_1, state_2 in zip(
                expected_output.hidden_states, actual.hidden_states[1:]
            ):
                assert_expected(state_1, state_2)

        assert actual.attentions == expected_output.attentions
        assert_expected(
            actual.last_hidden_state,
            expected_output.last_hidden_state,
            rtol=0,
            atol=1e-4,
        )

    @pytest.mark.parametrize(
        "norm_first, expected_output",
        [
            (
                True,
                TransformerOutput(
                    last_hidden_state=Tensor([[[1.9073e-05, 2.0], [2.2888e-05, 2.0]]]),
                    hidden_states=[
                        Tensor([[[16.0, 17.0], [15.0, 16.0]]]),
                        Tensor([[[30.0, 31.0], [29.0, 30.0]]]),
                    ],
                ),
            ),
            (
                False,
                TransformerOutput(
                    last_hidden_state=Tensor([[[5.0068e-06, 2.0], [5.0068e-06, 2.0]]]),
                    hidden_states=[
                        Tensor([[[0.0, 2.0], [0.0, 2.0]]]),
                        Tensor([[[0.0, 2.0], [0.0, 2.0]]]),
                    ],
                ),
            ),
        ],
    )
    def test_forward_with_final_ln(
        self, norm_first, expected_output, inputs, get_encoder
    ):
        model = get_encoder(norm_first=norm_first, final_layer_norm_eps=1e-5)
        actual = model(inputs, return_hidden_states=True)
        assert_expected(
            expected_output.last_hidden_state,
            actual.last_hidden_state,
            rtol=0,
            atol=1e-4,
        )
        for state_1, state_2 in zip(
            expected_output.hidden_states, actual.hidden_states[1:]
        ):
            assert_expected(state_1, state_2)

    @pytest.mark.parametrize(
        "norm_first",
        [(True,), (False,)],
    )
    def test_scripting(self, norm_first, inputs, get_encoder):
        model = get_encoder(norm_first)
        scripted_model = torch.jit.script(model)
        assert_expected(scripted_model(inputs), model(inputs), rtol=0, atol=1e-4)
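
The expected tensors in these tests are exact constants because init_weights_with_constant fills every parameter with the same value before the forward pass, which makes each layer's arithmetic checkable by hand. A minimal sketch of such a helper, assuming it simply constant-fills all parameters (the actual tests.test_utils implementation may differ):

from torch import nn

def init_weights_with_constant(model: nn.Module, constant: float = 1.0) -> None:
    # Hypothetical sketch: fill every learnable parameter with the same
    # constant so forward passes are deterministic and hand-checkable.
    for p in model.parameters():
        nn.init.constant_(p, constant)

This also explains the pattern in the expected values above: with norm_first=True the residual stream accumulates constant updates layer by layer (hidden states grow from [[16, 17], [15, 16]] to [[30, 31], [29, 30]]), while with norm_first=False each layer normalizes after the residual connection, so outputs stay pinned at [[0, 2], [0, 2]].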
