Commit e3ebd29

Peng Chen authored and facebook-github-bot committed
add blip2 loss under torchmultimodal/modules/losses (#485)
Summary: Pull Request resolved: #485. As title. Differential Revision: D50148648. fbshipit-source-id: 5c8444535286afa6be2c366a75da185ad45513ef
1 parent f1c891c commit e3ebd29
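
As a quick orientation before the listing: a minimal usage sketch of the new Blip2Phase1Loss, inferred from the tests added in this commit. The constant tensors mirror the test fixtures; `blip2` stands in for the BLIP-2 model object that the tests obtain from a fixture, so it is an assumption here rather than a documented API.

import torch
from torchmultimodal.models.blip2.blip2 import Blip2Output
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss

# Constant inputs mirroring the test fixtures; `blip2` stands in for a BLIP-2
# model instance (the tests get it from a fixture, so it is assumed here).
blip2_output = Blip2Output(
    image_embeddings=torch.ones([4, 5, 2]),
    image_features=torch.ones([4, 32, 4]) * 0.5,
    image_qformer_output=torch.ones([4, 32, 4]) * 0.5,
    text_features=torch.ones([4, 4]) * 0.5,
    prediction_scores=torch.ones([4, 4, 20]) * 5,
)
loss_fn = Blip2Phase1Loss(dim_q=4)  # per the tests, ITC, ITM and ITG are enabled by default
loss_out = loss_fn(
    model_output=blip2_output,
    blip2=blip2,
    input_ids=torch.ones(4, 4).long(),
    attention_mask=torch.ones(4, 4),
)
loss_out.total_loss.backward()  # combined scalar loss over the enabled terms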

File tree

2 files changed: +604, −0 lines changed

Lines changed: 244 additions & 0 deletions (the unit-test file for the new loss; the second changed file, the loss implementation in torchmultimodal/modules/losses/blip2_losses.py, is not shown in this excerpt)
@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import pytest
import torch
from tests.test_utils import (
    assert_expected,
    gpu_test,
    init_distributed_on_file,
    init_weights_with_constant,
    with_temp_files,
)
from torch import distributed as dist, multiprocessing as mp, nn, optim
from torchmultimodal.models.blip2.blip2 import Blip2Output
from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss


@pytest.fixture
def dim_q():
    return 4


class TestBLIP2Stage1Loss:
    # NOTE: the `blip2` model fixture requested by the tests below is not defined
    # in this file; it is assumed to be provided by a shared pytest conftest.

    @pytest.fixture
    def images(self):
        # batch of 4 constant RGB images of size 2x2
        return torch.ones(4, 3, 2, 2)

    @pytest.fixture
    def input_ids(self):
        # batch of 4 token-id sequences of length 4
        return torch.ones(4, 4).long()

    @pytest.fixture
    def all_attn_mask(self):
        return torch.ones([4, 4])

    @pytest.fixture
    def global_batch_size(self):
        return 4

    @pytest.fixture
    def blip2_output(self):
        # constant-valued Blip2Output so the expected loss values are deterministic
        return Blip2Output(
            image_embeddings=torch.ones([4, 5, 2]),
            image_features=torch.ones([4, 32, 4]) * 0.5,
            image_qformer_output=torch.ones([4, 32, 4]) * 0.5,
            text_features=torch.ones([4, 4]) * 0.5,
            prediction_scores=torch.ones([4, 4, 20]) * 5,
        )

    def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 5.07517, rtol=0, atol=1e-4)

    def test_local_itc_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itm=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 1.38629, rtol=0, atol=1e-4)

    def test_local_itm_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itg=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 0.69315, rtol=0, atol=1e-4)

    def test_local_itg_only_loss(
        self, all_attn_mask, blip2_output, blip2, dim_q, input_ids
    ):
        blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itm=False)
        init_weights_with_constant(blip2_loss)
        local_loss = blip2_loss(
            model_output=blip2_output,
            blip2=blip2,
            input_ids=input_ids,
            attention_mask=all_attn_mask,
        )
        assert_expected(local_loss.total_loss.item(), 2.9957, rtol=0, atol=1e-4)

    def test_invalid_loss_input(self, dim_q):
        # Disabling all three objectives is invalid and should raise.
        with pytest.raises(ValueError):
            Blip2Phase1Loss(
                dim_q=dim_q, enable_itc=False, enable_itm=False, enable_itg=False
            )

    @staticmethod
    def _model_worker(
        gpu_id: int,
        sync_file: str,
        world_size: int,
        global_batch_size: int,
        all_images: torch.Tensor,
        all_input_ids: torch.Tensor,
        all_attn_mask: torch.Tensor,
        blip2_output: Blip2Output,
        blip2: nn.Module,
        dim_q: int = 4,
    ):
        init_distributed_on_file(
            world_size=world_size, gpu_id=gpu_id, sync_file=sync_file
        )
        assert global_batch_size % world_size == 0
        local_batch_size = global_batch_size // world_size
        all_attn_mask = torch.ones([4, 4])

        # Split inputs across GPUs
        local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id)
        local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        local_attn_mask = torch.split(all_attn_mask, local_batch_size)[gpu_id].cuda(
            gpu_id
        )
        assert blip2_output.text_features is not None
        assert blip2_output.prediction_scores is not None
        local_blip2_output = Blip2Output(
            image_embeddings=torch.split(
                blip2_output.image_embeddings, local_batch_size
            )[gpu_id].cuda(gpu_id),
            image_features=torch.split(blip2_output.image_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            image_qformer_output=torch.split(
                blip2_output.image_qformer_output, local_batch_size
            )[gpu_id].cuda(gpu_id),
            text_features=torch.split(blip2_output.text_features, local_batch_size)[
                gpu_id
            ].cuda(gpu_id),
            prediction_scores=torch.split(
                blip2_output.prediction_scores, local_batch_size
            )[gpu_id].cuda(gpu_id),
        )

        blip2 = blip2.cuda(gpu_id)
        loss_fn = Blip2Phase1Loss(dim_q=dim_q)
        init_weights_with_constant(loss_fn)
        loss_fn = loss_fn.cuda(gpu_id)

        all_params = chain(blip2.parameters(), loss_fn.parameters())

        optimizer = optim.SGD(all_params, lr=1e-4)

        # Forward pass
        loss = loss_fn(
            model_output=local_blip2_output,
            blip2=blip2,
            images=local_images,
            input_ids=local_input_ids,
            attention_mask=local_attn_mask,
        ).total_loss

        # Compute gradients
        optimizer.zero_grad()
        loss.backward()

        # Gather gradients from all devices
        def gather_grads(x: torch.Tensor) -> torch.Tensor:
            grads = [torch.zeros_like(x).cuda(gpu_id) for _ in range(world_size)]
            dist.all_gather(grads, x)
            grad = torch.stack(grads).mean()
            return grad

        # Gather losses from all devices
        gathered_loss = gather_grads(torch.Tensor([loss]).cuda(gpu_id))
        assert_expected(gathered_loss.item(), 5.07517, rtol=0, atol=1e-4)

    @gpu_test(gpu_count=1)
    def test_single_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 1
            # Spawn one worker per GPU; each worker consumes its shard of the batch.
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )

    @gpu_test(gpu_count=2)
    def test_multi_gpu_loss(
        self,
        global_batch_size,
        images,
        input_ids,
        blip2_output,
        blip2,
        all_attn_mask,
        dim_q,
    ):
        with with_temp_files(count=1) as sync_file:
            world_size = 2
            mp.spawn(
                TestBLIP2Stage1Loss._model_worker,
                (
                    sync_file,
                    world_size,
                    global_batch_size,
                    images,
                    input_ids,
                    all_attn_mask,
                    blip2_output,
                    blip2,
                    dim_q,
                ),
                nprocs=world_size,
            )
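
A quick sanity check on the expected values above (my own arithmetic, not part of the commit): with constant-initialized weights and constant fixture tensors, each per-term test expectation matches the cross-entropy of a uniform prediction, and the three per-term values sum to the combined expectation, which is consistent with total_loss being the sum of the enabled ITC, ITM and ITG terms.

import math

# Cross-entropy of a uniform prediction over k classes is ln(k).
itc = math.log(4)   # 4-way image-text contrastive over the batch of 4   -> 1.38629
itm = math.log(2)   # binary image-text matching                         -> 0.69315
itg = math.log(20)  # generation over the 20-dim prediction_scores vocab -> 2.99573
print(round(itc + itm + itg, 5))  # 5.07517, the combined-loss expectation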
