
Commit d17c927

Peng Chen authored and facebook-github-bot committed
add blip2 loss under torchmultimodal/modules/losses (#485)
Summary: Pull Request resolved: #485. As title.
Differential Revision: D50148648
fbshipit-source-id: 6d2d2de339a413c6f0ef152328d67249a8f7be84
1 parent b577b6f commit d17c927

File tree

2 files changed: +671, -0 lines changed
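
The new Blip2Phase1Loss module combines the three BLIP-2 stage-1 pretraining objectives: image-text contrastive (ITC), image-text matching (ITM), and image-grounded text generation (ITG). Each objective can be toggled with the enable_itc, enable_itm, and enable_itg flags, and at least one must stay enabled. A minimal usage sketch, inferred solely from how the tests below call the module (the setup objects are assumed to exist):

    loss_fn = Blip2Phase1Loss(dim_q=4)  # all three objectives on by default
    output = loss_fn(
        model_output=blip2_output,   # Blip2Output from a BLIP2 forward pass
        blip2=blip2,                 # the BLIP2 model itself
        input_ids=input_ids,         # caption token ids, shape (batch, seq_len)
        attention_mask=attn_mask,    # text attention mask, same shape
    )
    output.total_loss.backward()     # total_loss sums the enabled objectives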
Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import pytest
import torch
from tests.test_utils import (
    assert_expected,
    gpu_test,
    init_distributed_on_file,
    init_weights_with_constant,
    with_temp_files,
)
from torch import distributed as dist, multiprocessing as mp, nn, optim
from torchmultimodal.models.blip2.blip2 import BLIP2, Blip2Output
from torchmultimodal.models.blip2.qformer_model import QformerForCLM
from torchmultimodal.modules.encoders.vision_transformer import VisionTransformer
from torchmultimodal.modules.layers.patch_embedding import PatchEmbeddings
from torchmultimodal.modules.layers.transformer import TransformerEncoder
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss


@pytest.fixture
def dim_q():
    return 4


@pytest.fixture
def dim_kv():
    return 2


@pytest.fixture
def vit():
    embedding = PatchEmbeddings(image_size=2, patch_size=1, hidden_size=2)
    encoder = TransformerEncoder(
        n_layer=1,
        d_model=2,
        n_head=1,
        dim_feedforward=1,
        activation=nn.GELU,
        norm_first=True,
        final_layer_norm_eps=1e-5,
    )
    image_encoder = VisionTransformer(
        embeddings=embedding,
        encoder=encoder,
    )
    init_weights_with_constant(image_encoder)
    image_encoder.eval()
    return image_encoder
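

# NOTE (assumption): the four fixtures below are required by
# qformer_model_for_clm but are not visible in this capture of the diff.
# Their values are assumed placeholders; vocab_size=20 is implied by the
# prediction_scores shape [4, 4, 20] used in blip2_output.
@pytest.fixture
def dim_feedforward():
    return 2


@pytest.fixture
def num_hidden_layers():
    return 2


@pytest.fixture
def num_heads():
    return 2


@pytest.fixture
def vocab_size():
    return 20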


class TestBLIP2Stage1Loss:
    @pytest.fixture
    def images(self):
        return torch.ones(4, 3, 2, 2)

    @pytest.fixture
    def input_ids(self):
        return torch.ones(4, 4).long()

    @pytest.fixture
    def all_attn_mask(self):
        return torch.ones([4, 4])

    @pytest.fixture
    def global_batch_size(self):
        return 4

    @pytest.fixture
    def qformer_model_for_clm(
        self,
        dim_q,
        dim_kv,
        dim_feedforward,
        num_hidden_layers,
        num_heads,
        vocab_size,
    ):
        qformer_for_clm = QformerForCLM(
            dim_q=dim_q,
            dim_kv=dim_kv,
            dim_feedforward=dim_feedforward,
            num_heads=num_heads,
            attn_dropout=0.0,
            dropout=0.0,
            num_hidden_layers=num_hidden_layers,
            max_position_embeddings=512,
            vocab_size=vocab_size,
        )
        return qformer_for_clm
@pytest.fixture
99+
def blip2_output(self):
100+
return Blip2Output(
101+
image_embeddings=torch.ones([4, 5, 2]),
102+
image_features=torch.ones([4, 32, 4]) * 0.5,
103+
image_qformer_output=torch.ones([4, 32, 4]) * 0.5,
104+
text_features=torch.ones([4, 4]) * 0.5,
105+
prediction_scores=torch.ones([4, 4, 20]) * 5,
106+
)
107+
108+
@pytest.fixture
109+
def blip2(self, dim_q, dim_kv, qformer_model_for_clm, vit):
110+
blip2 = BLIP2(
111+
dim_q=dim_q,
112+
image_encoder_embedding_dim=dim_kv,
113+
qformer=qformer_model_for_clm,
114+
vision_encoder=vit,
115+
embedding_dim=4,
116+
decoder_bos_token_id=19,
117+
)
118+
init_weights_with_constant(blip2)
119+
blip2.eval()
120+
return blip2
121+

    def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 5.07517, rtol=0, atol=1e-4)

    def test_local_itc_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itm=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 1.38629, rtol=0, atol=1e-4)

    def test_local_itm_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 0.69315, rtol=0, atol=1e-4)

    def test_local_itg_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itm=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 2.9957, rtol=0, atol=1e-4)
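
    # A note on the expected values above: init_weights_with_constant makes
    # every logit identical, so each softmax is uniform. That is consistent
    # with ITC over a batch of 4 candidates giving ln(4) = 1.38629, the binary
    # ITM head giving ln(2) = 0.69315, and ITG over a vocab of 20 giving
    # ln(20) = 2.99573; the combined loss is their sum, 5.07517.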
    def test_invalid_loss_input(self, dim_q):
        # Disabling all three objectives is invalid and should raise.
        with pytest.raises(ValueError):
            Blip2Phase1Loss(
                dim_q=dim_q, enable_itc=False, enable_itm=False, enable_itg=False
            )

    @staticmethod
    def _model_worker(
        gpu_id: int,
        sync_file: str,
        world_size: int,
        global_batch_size: int,
        all_images: torch.Tensor,
        all_input_ids: torch.Tensor,
        all_attn_mask: torch.Tensor,
        blip2_output: Blip2Output,
        blip2: nn.Module,
        dim_q: int,
    ):
        init_distributed_on_file(
            world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
        )
        assert global_batch_size % world_size == 0
        local_batch_size = global_batch_size // world_size

        # Split inputs across GPUs
        local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id)
        local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        local_attn_mask = torch.split(all_attn_mask, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        assert blip2_output.text_features is not None
        assert blip2_output.prediction_scores is not None
        local_blip2_output = Blip2Output(
            image_embeddings=torch.split(
                blip2_output.image_embeddings, local_batch_size
            )[gpu_id].cuda(gpu_id),
            image_features=torch.split(blip2_output.image_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            image_qformer_output=torch.split(
                blip2_output.image_qformer_output, local_batch_size
            )[gpu_id].cuda(gpu_id),
            text_features=torch.split(blip2_output.text_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            prediction_scores=torch.split(
                blip2_output.prediction_scores, local_batch_size
            )[gpu_id].cuda(gpu_id),
        )

        blip2 = blip2.cuda(gpu_id)
        loss_fn = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(loss_fn)
        loss_fn = loss_fn.cuda(gpu_id)

        all_params = chain(blip2.parameters(), loss_fn.parameters())
        optimizer = optim.SGD(all_params, lr=1e-4)

        # Forward pass
        loss = loss_fn(
            model_output=local_blip2_output,
            blip2=blip2,
            images=local_images,
            input_ids=local_input_ids,
            attention_mask=local_attn_mask,
        ).total_loss

        # Compute gradients
        optimizer.zero_grad()
        loss.backward()

        # Gather a tensor from every device and average it
        def gather_grads(x: torch.Tensor) -> torch.Tensor:
            grads = [torch.zeros_like(x).cuda(gpu_id) for _ in range(world_size)]
            dist.all_gather(grads, x)
            grad = torch.stack(grads).mean()
            return grad

        # Gather losses from all devices
        gathered_loss = gather_grads(torch.Tensor([loss]).cuda(gpu_id))
        assert_expected(gathered_loss.item(), 5.07517, rtol=0, atol=1e-4)

    @gpu_test(gpu_count=1)
    def test_single_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 1
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )

    @gpu_test(gpu_count=2)
    def test_multi_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 2
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )
