
Commit 331ecd7

Peng Chen authored and facebook-github-bot committed
add blip2 loss under torchmultimodal/modules/losses (#485)
Summary: Pull Request resolved: #485. As title.
Differential Revision: D50148648
fbshipit-source-id: c5f0525e89b1c8faaab484d122736a8783f3af92
1 parent 07b8bdc commit 331ecd7
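
For orientation, the new Blip2Phase1Loss combines BLIP-2's three stage-1 objectives: image-text contrastive (ITC), image-text matching (ITM), and image-grounded text generation (ITG). A minimal sketch of how the loss is driven, inferred from the test code in this diff — phase1_train_step is a hypothetical helper, not part of the commit; the call signature and the total_loss field are taken from the tests below:

import torch
from torch import nn, optim
from torchmultimodal.models.blip2.blip2 import Blip2Output
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss


def phase1_train_step(
    blip2: nn.Module,
    model_output: Blip2Output,  # output of a BLIP2 forward pass
    images: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    loss_fn: Blip2Phase1Loss,  # e.g. Blip2Phase1Loss(dim_q=4)
    optimizer: optim.Optimizer,
) -> float:
    # The loss module reduces ITC + ITM + ITG into a single scalar,
    # read via `.total_loss` exactly as the tests below do.
    losses = loss_fn(
        model_output=model_output,
        blip2=blip2,
        images=images,
        input_ids=input_ids,
        attention_mask=attention_mask,
    )
    optimizer.zero_grad()
    losses.total_loss.backward()
    optimizer.step()
    return losses.total_loss.item()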

File tree: 2 files changed (+618, -0)
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import pytest
import torch
from tests.test_utils import (
    assert_expected,
    gpu_test,
    init_distributed_on_file,
    init_weights_with_constant,
    with_temp_files,
)
from torch import distributed as dist, multiprocessing as mp, nn, optim
from torchmultimodal.models.blip2.blip2 import BLIP2, Blip2Output
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss


@pytest.fixture
def dim_q():
    return 4


class TestBLIP2Stage1Loss:
    @pytest.fixture
    def images(self):
        return torch.ones(4, 3, 2, 2)

    @pytest.fixture
    def input_ids(self):
        return torch.ones(4, 4).long()

    @pytest.fixture
    def all_attn_mask(self):
        return torch.ones([4, 4])

    @pytest.fixture
    def global_batch_size(self):
        return 4

    @pytest.fixture
    def blip2_output(self):
        # prediction_scores: batch 4, sequence length 4, vocab size 20.
        return Blip2Output(
            image_embeddings=torch.ones([4, 5, 2]),
            image_features=torch.ones([4, 32, 4]) * 0.5,
            image_qformer_output=torch.ones([4, 32, 4]) * 0.5,
            text_features=torch.ones([4, 4]) * 0.5,
            prediction_scores=torch.ones([4, 4, 20]) * 5,
        )

    @pytest.fixture
    def blip2(self, dim_q, dim_kv, qformer_model_for_clm, vit):
        # `dim_kv`, `qformer_model_for_clm`, and `vit` are shared fixtures
        # defined outside this file (presumably in a conftest).
        blip2 = BLIP2(
            dim_q=dim_q,
            image_encoder_embedding_dim=dim_kv,
            qformer=qformer_model_for_clm,
            vision_encoder=vit,
            embedding_dim=4,
            decoder_bos_token_id=19,
        )
        init_weights_with_constant(blip2)
        blip2.eval()
        return blip2

    def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 5.07517, rtol=0, atol=1e-4)

    def test_local_itc_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itm=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 1.38629, rtol=0, atol=1e-4)

    def test_local_itm_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 0.69315, rtol=0, atol=1e-4)

    def test_local_itg_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itm=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 2.9957, rtol=0, atol=1e-4)
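
    # Note: the expected values above are consistent with the constant weight
    # init producing uniform predictions, where cross-entropy reduces to
    # ln(num_classes): ITC over the 4 in-batch candidates -> ln(4) = 1.38629,
    # the binary ITM head -> ln(2) = 0.69315, and ITG over the fixture's
    # 20-token vocabulary -> ln(20) = 2.99573. Their sum, ln(160) = 5.07517,
    # matches test_local_loss.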

    def test_invalid_loss_input(self, dim_q):
        # Disabling all three objectives at once is invalid.
        with pytest.raises(ValueError):
            Blip2Phase1Loss(
                dim_q=dim_q, enable_itc=False, enable_itm=False, enable_itg=False
            )

    @staticmethod
    def _model_worker(
        gpu_id: int,
        sync_file: str,
        world_size: int,
        global_batch_size: int,
        all_images: torch.Tensor,
        all_input_ids: torch.Tensor,
        all_attn_mask: torch.Tensor,
        blip2_output: Blip2Output,
        blip2: nn.Module,
        dim_q: int,
    ):
        init_distributed_on_file(
            world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
        )
        assert global_batch_size % world_size == 0
        local_batch_size = global_batch_size // world_size
        # Recreate the full attention mask on this process (same shape and
        # values as the all_attn_mask fixture).
        all_attn_mask = torch.ones([4, 4])

        # Split inputs across GPUs
        local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id)
        local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        local_attn_mask = torch.split(all_attn_mask, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        assert blip2_output.text_features is not None
        assert blip2_output.prediction_scores is not None
        local_blip2_output = Blip2Output(
            image_embeddings=torch.split(
                blip2_output.image_embeddings, local_batch_size
            )[gpu_id].cuda(gpu_id),
            image_features=torch.split(blip2_output.image_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            image_qformer_output=torch.split(
                blip2_output.image_qformer_output, local_batch_size
            )[gpu_id].cuda(gpu_id),
            text_features=torch.split(blip2_output.text_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            prediction_scores=torch.split(
                blip2_output.prediction_scores, local_batch_size
            )[gpu_id].cuda(gpu_id),
        )

        blip2 = blip2.cuda(gpu_id)
        loss_fn = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(loss_fn)
        loss_fn = loss_fn.cuda(gpu_id)

        all_params = chain(blip2.parameters(), loss_fn.parameters())

        optimizer = optim.SGD(all_params, lr=1e-4)

        # Forward pass
        loss = loss_fn(
            model_output=local_blip2_output,
            blip2=blip2,
            images=local_images,
            input_ids=local_input_ids,
            attention_mask=local_attn_mask,
        ).total_loss

        # Compute gradients
        optimizer.zero_grad()
        loss.backward()

        # Gather gradients from all devices
        def gather_grads(x: torch.Tensor) -> torch.Tensor:
            grads = [torch.zeros_like(x).cuda(gpu_id) for _ in range(world_size)]
            dist.all_gather(grads, x)
            grad = torch.stack(grads).mean()
            return grad

        # Gather losses from all devices
        gathered_loss = gather_grads(torch.Tensor([loss]).cuda(gpu_id))
        assert_expected(gathered_loss.item(), 5.07517, rtol=0, atol=1e-4)
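
    # Note: the mean of the all-gathered per-device losses is asserted against
    # the same value as test_local_loss (5.07517), so the distributed runs
    # below are checked for consistency with the single-process path.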

    @gpu_test(gpu_count=1)
    def test_single_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 1
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )

    @gpu_test(gpu_count=2)
    def test_multi_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 2
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )
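
The tests above also pin down the configuration surface: each objective can be disabled independently via enable_itc / enable_itm / enable_itg, and disabling all three raises ValueError. A short usage sketch (not part of the commit):

from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss

itc_only = Blip2Phase1Loss(dim_q=4, enable_itm=False, enable_itg=False)
itm_only = Blip2Phase1Loss(dim_q=4, enable_itc=False, enable_itg=False)
itg_only = Blip2Phase1Loss(dim_q=4, enable_itc=False, enable_itm=False)

try:
    Blip2Phase1Loss(dim_q=4, enable_itc=False, enable_itm=False, enable_itg=False)
except ValueError:
    pass  # at least one of ITC / ITM / ITG must stay enabled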
