Commit 476d23b

[FLAVA] Separate the pretraining loss from the pretraining model
ghstack-source-id: 417f074
Pull Request resolved: #278
1 parent 06d3436 commit 476d23b
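In short: FLAVAForPreTraining no longer owns a FLAVAPretrainingLoss. Its forward now returns a FLAVAForPretrainingOutput carrying the raw sequences, labels, and head outputs, and callers construct the loss module themselves around the model's logit_scale. A minimal sketch of the resulting calling pattern, mirroring the updated tests below (the helper name compute_pretraining_loss is illustrative, not part of the API):

from torchmultimodal.models.flava.model import flava_model_for_pretraining
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss


def compute_pretraining_loss(flava, losses, **model_inputs):
    # Run the model first; after this change it returns a FLAVAForPretrainingOutput
    # instead of computing the loss internally.
    output = flava(**model_inputs)
    # Feed every field of that output into the standalone loss module.
    return losses(
        multimodal_masked_sequence=output.multimodal_masked_sequence,
        pos_mask=output.pos_mask,
        itm_labels=output.itm_labels,
        mim_labels=output.mim_labels,
        mlm_labels=output.mlm_labels,
        mmm_mlm_labels=output.mmm_mlm_labels,
        mmm_mim_labels=output.mmm_mim_labels,
        projected_image_embeddings=output.projected_image_embeddings,
        projected_text_embeddings=output.projected_text_embeddings,
        itm_logits=output.itm_logits,
        mlm_head_output=output.mlm_head_output,
        mim_head_output=output.mim_head_output,
        mmm_mlm_head_output=output.mmm_mlm_head_output,
        mmm_mim_head_output=output.mmm_mim_head_output,
    )


flava = flava_model_for_pretraining()
# The loss shares the model's logit_scale parameter for the contrastive term.
losses = FLAVAPretrainingLoss(logit_scale=flava.logit_scale)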

File tree

4 files changed: +153 −16 lines changed

examples/flava/model.py
test/models/flava/test_checkpoint.py
test/models/flava/test_flava.py
torchmultimodal/models/flava/model.py


examples/flava/model.py

Lines changed: 22 additions & 2 deletions

@@ -8,16 +8,18 @@
 
 import torch
 from pytorch_lightning import LightningModule
+from torch import nn
 from torchmetrics import Accuracy
 from torchmultimodal.models.flava.model import (
     flava_model_for_classification,
     flava_model_for_pretraining,
 )
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
 from transformers.optimization import get_cosine_schedule_with_warmup
 
 
 def get_optimizers_for_lightning(
-    model: torch.nn.Module,
+    model: nn.Module,
     learning_rate: float,
     adam_eps: float,
     adam_weight_decay: float,
@@ -59,6 +61,7 @@ def __init__(
         self.adam_weight_decay = adam_weight_decay
         self.warmup_steps = warmup_steps
         self.max_steps = max_steps
+        self.loss = FLAVAPretrainingLoss(logit_scale=self.model.logit_scale)
 
     def training_step(self, batch, batch_idx):
         output = self._step(batch, batch_idx)
@@ -104,7 +107,24 @@ def _step(self, batch, batch_idx):
             itm_labels=batch.get("itm_labels", None),
             required_embedding=required_embedding,
         )
-        return output
+
+        loss = self.loss(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
+        return loss
 
     def configure_optimizers(self):
         return get_optimizers_for_lightning(
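With this change the Lightning example builds the loss in __init__ and `_step` returns the FLAVAPretrainingLossOutput rather than the raw model output. How `training_step` reduces that output is outside this diff; a hedged sketch of one option, assuming (as the test expectations below suggest) that the loss output exposes a `losses` dict whose entries are None for heads that were not active on the batch:

    def training_step(self, batch, batch_idx):
        # `_step` now returns the FLAVAPretrainingLossOutput (see the diff above).
        output = self._step(batch, batch_idx)
        # Illustrative reduction, not the example's actual code: sum whichever
        # per-task losses were computed for this batch's modalities.
        total_loss = sum(value for value in output.losses.values() if value is not None)
        self.log("train/total_loss", total_loss)
        return total_loss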

test/models/flava/test_checkpoint.py

Lines changed: 51 additions & 1 deletion

@@ -12,6 +12,7 @@
     flava_model_for_classification,
     flava_model_for_pretraining,
 )
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
 
 
 @pytest.fixture(autouse=True)
@@ -139,8 +140,25 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
         image_input = inputs_pretraining("image")
         text_input = inputs_pretraining("text")
         flava = pretraining_model()
-
+        losses = FLAVAPretrainingLoss(flava.logit_scale)
         output = flava(*mm_input)
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
+
         actual = output.losses
         expected = dict(
             mmm_text_loss=10.9567,
@@ -153,6 +171,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
         self._assert_tensor_dicts_equal(actual, expected)
 
         output = flava(*image_input)
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         actual = output.losses
         expected = dict(
             mmm_text_loss=None,
@@ -165,6 +199,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
         self._assert_tensor_dicts_equal(actual, expected)
 
         output = flava(*text_input)
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         actual = output.losses
         expected = dict(
             mmm_text_loss=None,

test/models/flava/test_flava.py

Lines changed: 50 additions & 0 deletions

@@ -18,6 +18,7 @@
     FLAVAOutput,
 )
 from torchmultimodal.modules.layers.transformer import TransformerOutput
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
 
 NUM_CLASSES = 2
 
@@ -61,6 +62,7 @@ def test_forward_pretraining(self):
         mlm_labels[:, 1:3] = text[:, 1:3]
         itm_labels = torch.tensor((0, 1), dtype=torch.long)
         flava = flava_model_for_pretraining()
+        losses = FLAVAPretrainingLoss(flava.logit_scale)
         flava.eval()
         output = flava(
             image=image,
@@ -72,6 +74,22 @@ def test_forward_pretraining(self):
             itm_labels=itm_labels,
             mlm_labels=mlm_labels,
         )
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         self.assertIsNone(output.mlm_output)
         self.assertIsNone(output.mim_output)
         self.assertIsNotNone(output.global_contrastive_output)
@@ -96,6 +114,22 @@ def test_forward_pretraining(self):
             itm_labels=itm_labels,
             mlm_labels=mlm_labels,
         )
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         self.assertIsNone(output.mlm_output)
         self.assertIsNotNone(output.mim_output)
         self.assertIsNone(output.global_contrastive_output)
@@ -120,6 +154,22 @@ def test_forward_pretraining(self):
             itm_labels=itm_labels,
             mlm_labels=mlm_labels,
         )
+        output = losses(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
         self.assertIsNotNone(output.mlm_output)
         self.assertIsNone(output.mim_output)
         self.assertIsNone(output.global_contrastive_output)

torchmultimodal/models/flava/model.py

Lines changed: 30 additions & 13 deletions

@@ -24,11 +24,7 @@
     TransformerEncoder,
     TransformerOutput,
 )
-from torchmultimodal.modules.losses.flava import (
-    FLAVAPretrainingLoss,
-    FLAVAPretrainingLossOutput,
-    Pooler,
-)
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput, Pooler
 from torchmultimodal.utils.common import ModelOutput, PretrainedMixin
 from typing_extensions import Literal
 
@@ -62,7 +58,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_no_loss.pt",
 }
 
 FLAVA_MODEL_MAPPING = {
@@ -105,6 +101,24 @@ class FLAVAForClassificationOutput(ModelOutput):
     loss: Tensor
 
 
+@dataclass
+class FLAVAForPretrainingOutput:
+    multimodal_masked_sequence: Tensor
+    pos_mask: Tensor
+    mim_labels: Tensor
+    mlm_labels: Tensor
+    mmm_mlm_labels: Tensor
+    mmm_mim_labels: Tensor
+    itm_labels: Tensor
+    projected_image_embeddings: Tensor
+    projected_text_embeddings: Tensor
+    itm_logits: Tensor
+    mlm_head_output: Tensor
+    mim_head_output: Tensor
+    mmm_mlm_head_output: Tensor
+    mmm_mim_head_output: Tensor
+
+
 class FLAVAModel(nn.Module, PretrainedMixin):
     def __init__(
         self,
@@ -367,22 +381,22 @@ def __init__(
         self,
         model: FLAVAModel,
         image_codebook: nn.Module,
-        loss: nn.Module,
         itm_head: nn.Module,
         mlm_head: nn.Module,
         mim_head: nn.Module,
         mmm_mlm_head: nn.Module,
         mmm_mim_head: nn.Module,
+        logit_scale: nn.Module,
     ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
-        self.loss = loss
         self.itm_head = itm_head
         self.mlm_head = mlm_head
        self.mim_head = mim_head
         self.mmm_mlm_head = mmm_mlm_head
         self.mmm_mim_head = mmm_mim_head
+        self.logit_scale = logit_scale
 
     def encode_image(
         self,
@@ -469,8 +483,10 @@ def forward(
         itm_logits = self.itm_head(multimodal_masked_sequence)
 
         multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
+
         if mlm_labels is not None:
             mmm_mlm_labels = mlm_labels[pos_mask]
+
         if image_labels is not None:
             mmm_mim_labels = image_labels[pos_mask]
 
@@ -494,14 +510,14 @@
             sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
             mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)
 
-        return self.loss(
+        return FLAVAForPretrainingOutput(
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
             pos_mask=pos_mask,
-            itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
-            mmm_mlm_labels=mmm_mlm_labels,
             mmm_mim_labels=mmm_mim_labels,
+            mmm_mlm_labels=mmm_mlm_labels,
+            itm_labels=itm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
@@ -660,6 +676,7 @@ def flava_model_for_pretraining(
     codebook_image_size: int = 112,
     pretrained_model_key: Optional[str] = None,
     image_vocab_size: int = 8192,
+    logit_scale: float = math.log(1 / 0.07),
     **flava_model_kwargs: Any,
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
@@ -677,18 +694,18 @@
     mmm_mim_head = MaskedPredictionHead(
         hidden_size=hidden_size, vocab_size=image_vocab_size
     )
-    losses = FLAVAPretrainingLoss()
+
    codebook = DalleVAEEncoder(image_size=codebook_image_size)
 
     flava = FLAVAForPreTraining(
         model=model,
         image_codebook=codebook,
-        loss=losses,
         itm_head=itm_head,
         mlm_head=mlm_head,
         mim_head=mim_head,
         mmm_mlm_head=mmm_mlm_head,
         mmm_mim_head=mmm_mim_head,
+        logit_scale=nn.Parameter(logit_scale * torch.ones([])),
     )
 
     if pretrained_model_key is not None:
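For reference, the new logit_scale default is the familiar CLIP-style temperature initialization; a quick sanity check of what that number means (plain arithmetic, not code from this commit):

import math

logit_scale_init = math.log(1 / 0.07)  # ≈ 2.659, the default added above
print(math.exp(logit_scale_init))      # ≈ 14.286, i.e. contrastive logits are scaled by 1/0.07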

Comments (0)