|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from itertools import chain |
| 8 | + |
| 9 | +import pytest |
| 10 | +import torch |
| 11 | +from tests.test_utils import ( |
| 12 | + assert_expected, |
| 13 | + gpu_test, |
| 14 | + init_distributed_on_file, |
| 15 | + init_weights_with_constant, |
| 16 | + with_temp_files, |
| 17 | +) |
| 18 | +from torch import distributed as dist, multiprocessing as mp, nn, optim |
| 19 | +from torchmultimodal.models.blip2.blip2 import BLIP2, Blip2Output |
| 20 | +from torchmultimodal.modules.losses.blip2_losses import Blip2Phase1Loss |
| 21 | + |
| 22 | + |
| 23 | +@pytest.fixture |
| 24 | +def dim_q(): |
| 25 | + return 4 |
| 26 | + |
| 27 | + |
| 28 | +class TestBLIP2Stage1Loss: |
| 29 | + @pytest.fixture |
| 30 | + def images(self): |
| 31 | + return torch.ones(4, 3, 2, 2) |
| 32 | + |
| 33 | + @pytest.fixture |
| 34 | + def input_ids(self): |
| 35 | + return torch.ones(4, 4).long() |
| 36 | + |
| 37 | + @pytest.fixture |
| 38 | + def all_attn_mask(self): |
| 39 | + return torch.ones([4, 4]) |
| 40 | + |
| 41 | + @pytest.fixture |
| 42 | + def global_batch_size(self): |
| 43 | + return 4 |
| 44 | + |
| 45 | + @pytest.fixture |
| 46 | + def blip2_output(self): |
| 47 | + return Blip2Output( |
| 48 | + image_embeddings=torch.ones([4, 5, 2]), |
| 49 | + image_features=torch.ones([4, 32, 4]) * 0.5, |
| 50 | + image_qformer_output=torch.ones([4, 32, 4]) * 0.5, |
| 51 | + text_features=torch.ones([4, 4]) * 0.5, |
| 52 | + prediction_scores=torch.ones([4, 4, 20]) * 5, |
| 53 | + ) |
| 54 | + |
| 55 | + @pytest.fixture |
| 56 | + def blip2(self, dim_q, dim_kv, qformer_model_for_clm, vit): |
| 57 | + blip2 = BLIP2( |
| 58 | + dim_q=dim_q, |
| 59 | + image_encoder_embedding_dim=dim_kv, |
| 60 | + qformer=qformer_model_for_clm, |
| 61 | + vision_encoder=vit, |
| 62 | + embedding_dim=4, |
| 63 | + decoder_bos_token_id=19, |
| 64 | + ) |
| 65 | + init_weights_with_constant(blip2) |
| 66 | + blip2.eval() |
| 67 | + return blip2 |
| 68 | + |
| 69 | + def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids): |
| 70 | + blip2_loss = Blip2Phase1Loss(dim_q=dim_q) |
| 71 | + init_weights_with_constant(blip2_loss) |
| 72 | + local_loss = blip2_loss( |
| 73 | + model_output=blip2_output, |
| 74 | + blip2=blip2, |
| 75 | + input_ids=input_ids, |
| 76 | + attention_mask=all_attn_mask, |
| 77 | + ) |
| 78 | + assert_expected(local_loss.total_loss.item(), 5.07517, rtol=0, atol=1e-4) |
| 79 | + |
| 80 | + def test_local_itc_only_loss( |
| 81 | + self, all_attn_mask, blip2_output, blip2, dim_q, input_ids |
| 82 | + ): |
| 83 | + blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itm=False, enable_itg=False) |
| 84 | + init_weights_with_constant(blip2_loss) |
| 85 | + local_loss = blip2_loss( |
| 86 | + model_output=blip2_output, |
| 87 | + blip2=blip2, |
| 88 | + input_ids=input_ids, |
| 89 | + attention_mask=all_attn_mask, |
| 90 | + ) |
| 91 | + assert_expected(local_loss.total_loss.item(), 1.38629, rtol=0, atol=1e-4) |
| 92 | + |
| 93 | + def test_local_itm_only_loss( |
| 94 | + self, all_attn_mask, blip2_output, blip2, dim_q, input_ids |
| 95 | + ): |
| 96 | + blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itg=False) |
| 97 | + init_weights_with_constant(blip2_loss) |
| 98 | + local_loss = blip2_loss( |
| 99 | + model_output=blip2_output, |
| 100 | + blip2=blip2, |
| 101 | + input_ids=input_ids, |
| 102 | + attention_mask=all_attn_mask, |
| 103 | + ) |
| 104 | + assert_expected(local_loss.total_loss.item(), 0.69315, rtol=0, atol=1e-4) |
| 105 | + |
| 106 | + def test_local_itg_only_loss( |
| 107 | + self, all_attn_mask, blip2_output, blip2, dim_q, input_ids |
| 108 | + ): |
| 109 | + blip2_loss = Blip2Phase1Loss(dim_q=dim_q, enable_itc=False, enable_itm=False) |
| 110 | + init_weights_with_constant(blip2_loss) |
| 111 | + local_loss = blip2_loss( |
| 112 | + model_output=blip2_output, |
| 113 | + blip2=blip2, |
| 114 | + input_ids=input_ids, |
| 115 | + attention_mask=all_attn_mask, |
| 116 | + ) |
| 117 | + assert_expected(local_loss.total_loss.item(), 2.9957, rtol=0, atol=1e-4) |
| 118 | + |
| 119 | + def test_invalid_loss_input(self): |
| 120 | + with pytest.raises(ValueError): |
| 121 | + Blip2Phase1Loss( |
| 122 | + dim_q=dim_q, enable_itc=False, enable_itm=False, enable_itg=False |
| 123 | + ) |
| 124 | + |
| 125 | + @staticmethod |
| 126 | + def _model_worker( |
| 127 | + gpu_id: int, |
| 128 | + sync_file: str, |
| 129 | + world_size: int, |
| 130 | + global_batch_size: int, |
| 131 | + all_images: torch.Tensor, |
| 132 | + all_input_ids: torch.Tensor, |
| 133 | + all_attn_mask: torch.Tensor, |
| 134 | + blip2_output: Blip2Output, |
| 135 | + blip2: nn.Module, |
| 136 | + dim_q=dim_q, |
| 137 | + ): |
| 138 | + init_distributed_on_file( |
| 139 | + world_size=world_size, gpu_id=gpu_id, sync_file=sync_file |
| 140 | + ) |
| 141 | + assert global_batch_size % world_size == 0 |
| 142 | + local_batch_size = global_batch_size // world_size |
| 143 | + all_attn_mask = torch.ones([4, 4]) |
| 144 | + |
| 145 | + # Split inputs across GPUs |
| 146 | + local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id) |
| 147 | + local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda( |
| 148 | + gpu_id |
| 149 | + ) |
| 150 | + local_attn_mask = torch.split(all_attn_mask, local_batch_size)[gpu_id].cuda( |
| 151 | + gpu_id |
| 152 | + ) |
| 153 | + assert blip2_output.text_features is not None |
| 154 | + assert blip2_output.prediction_scores is not None |
| 155 | + local_blip2_output = Blip2Output( |
| 156 | + image_embeddings=torch.split( |
| 157 | + blip2_output.image_embeddings, local_batch_size |
| 158 | + )[gpu_id].cuda(gpu_id), |
| 159 | + image_features=torch.split(blip2_output.image_features, local_batch_size)[ |
| 160 | + gpu_id |
| 161 | + ].cuda(gpu_id), |
| 162 | + image_qformer_output=torch.split( |
| 163 | + blip2_output.image_qformer_output, local_batch_size |
| 164 | + )[gpu_id].cuda(gpu_id), |
| 165 | + text_features=torch.split(blip2_output.text_features, local_batch_size)[ |
| 166 | + gpu_id |
| 167 | + ].cuda(gpu_id), |
| 168 | + prediction_scores=torch.split( |
| 169 | + blip2_output.prediction_scores, local_batch_size |
| 170 | + )[gpu_id].cuda(gpu_id), |
| 171 | + ) |
| 172 | + |
| 173 | + blip2 = blip2.cuda(gpu_id) |
| 174 | + loss_fn = Blip2Phase1Loss(dim_q=dim_q) |
| 175 | + init_weights_with_constant(loss_fn) |
| 176 | + loss_fn = loss_fn.cuda(gpu_id) |
| 177 | + |
| 178 | + all_params = chain(blip2.parameters(), loss_fn.parameters()) |
| 179 | + |
| 180 | + optimizer = optim.SGD(all_params, lr=1e-4) |
| 181 | + |
| 182 | + # Forward pass |
| 183 | + loss = loss_fn( |
| 184 | + model_output=local_blip2_output, |
| 185 | + blip2=blip2, |
| 186 | + images=local_images, |
| 187 | + input_ids=local_input_ids, |
| 188 | + attention_mask=local_attn_mask, |
| 189 | + ).total_loss |
| 190 | + |
| 191 | + # Compute gradients |
| 192 | + optimizer.zero_grad() |
| 193 | + loss.backward() |
| 194 | + |
| 195 | + # Gather gradients from all devices |
| 196 | + def gather_grads(x: torch.Tensor) -> torch.Tensor: |
| 197 | + grads = [torch.zeros_like(x).cuda(gpu_id) for i in range(world_size)] |
| 198 | + dist.all_gather(grads, x) |
| 199 | + grad = torch.stack(grads).mean() |
| 200 | + return grad |
| 201 | + |
| 202 | + # Gather losses from all devices |
| 203 | + gathered_loss = gather_grads(torch.Tensor([loss]).cuda(gpu_id)) |
| 204 | + assert_expected(gathered_loss.item(), 5.07517, rtol=0, atol=1e-4) |
| 205 | + |
| 206 | + @gpu_test(gpu_count=1) |
| 207 | + def test_single_gpu_loss( |
| 208 | + self, |
| 209 | + global_batch_size, |
| 210 | + input_ids, |
| 211 | + blip2_output, |
| 212 | + blip2, |
| 213 | + attn_mask, |
| 214 | + dim_q, |
| 215 | + ): |
| 216 | + with with_temp_files(count=1) as sync_file: |
| 217 | + world_size = 1 |
| 218 | + mp.spawn( |
| 219 | + TestBLIP2Stage1Loss._model_worker, |
| 220 | + ( |
| 221 | + sync_file, |
| 222 | + world_size, |
| 223 | + global_batch_size, |
| 224 | + input_ids, |
| 225 | + attn_mask, |
| 226 | + blip2_output, |
| 227 | + blip2, |
| 228 | + dim_q, |
| 229 | + ), |
| 230 | + nprocs=world_size, |
| 231 | + ) |
| 232 | + |
| 233 | + @gpu_test(gpu_count=2) |
| 234 | + def test_multi_gpu_loss( |
| 235 | + self, |
| 236 | + global_batch_size, |
| 237 | + input_ids, |
| 238 | + blip2_output, |
| 239 | + blip2, |
| 240 | + attn_mask, |
| 241 | + dim_q, |
| 242 | + ): |
| 243 | + with with_temp_files(count=1) as sync_file: |
| 244 | + world_size = 2 |
| 245 | + mp.spawn( |
| 246 | + TestBLIP2Stage1Loss._model_worker, |
| 247 | + ( |
| 248 | + sync_file, |
| 249 | + world_size, |
| 250 | + global_batch_size, |
| 251 | + input_ids, |
| 252 | + attn_mask, |
| 253 | + blip2_output, |
| 254 | + blip2, |
| 255 | + dim_q, |
| 256 | + ), |
| 257 | + nprocs=world_size, |
| 258 | + ) |
0 commit comments