Commit fdba1c9

[FLAVA] Separate the pretraining loss from the pretraining model
ghstack-source-id: 020d925
Pull Request resolved: #278
1 parent ff9b47c

File tree (3 files changed: +65 −18)

  examples/flava/model.py
  torchmultimodal/models/flava/model.py
  torchmultimodal/modules/losses/flava.py

examples/flava/model.py

Lines changed: 27 additions & 6 deletions
@@ -4,20 +4,22 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Tuple
+from typing import Any, List, Tuple
 
 import torch
 from pytorch_lightning import LightningModule
+from torch import Tensor
 from torchmetrics import Accuracy
 from torchmultimodal.models.flava.model import (
     flava_model_for_classification,
     flava_model_for_pretraining,
 )
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
 from transformers.optimization import get_cosine_schedule_with_warmup
 
 
 def get_optimizers_for_lightning(
-    model: torch.nn.Module,
+    parameters: List[Tensor],
     learning_rate: float,
     adam_eps: float,
     adam_weight_decay: float,
@@ -26,7 +28,7 @@ def get_optimizers_for_lightning(
     max_steps: int,
 ):
     optimizer = torch.optim.AdamW(
-        model.parameters(),
+        parameters,
         lr=learning_rate,
         betas=adam_betas,
         eps=adam_eps,
@@ -59,6 +61,7 @@ def __init__(
         self.adam_weight_decay = adam_weight_decay
         self.warmup_steps = warmup_steps
         self.max_steps = max_steps
+        self.loss = FLAVAPretrainingLoss()
 
     def training_step(self, batch, batch_idx):
         output = self._step(batch, batch_idx)
@@ -104,11 +107,29 @@ def _step(self, batch, batch_idx):
             itm_labels=batch.get("itm_labels", None),
             required_embedding=required_embedding,
         )
-        return output
+
+        loss = self.loss(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
+        return loss
 
     def configure_optimizers(self):
+        parameters = self.model.parameters() + self.loss.parameters()
         return get_optimizers_for_lightning(
-            self.model,
+            parameters,
             self.learning_rate,
             self.adam_eps,
             self.adam_weight_decay,
@@ -194,7 +215,7 @@ def _step(self, batch, batch_idx):
 
     def configure_optimizers(self):
         return get_optimizers_for_lightning(
-            self.model,
+            self.model.parameters(),
             self.learning_rate,
             self.adam_eps,
             self.adam_weight_decay,
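
With this change the Lightning pretraining module owns a FLAVAPretrainingLoss instance and feeds it the raw model outputs, so configure_optimizers has to cover the parameters of both the model and the loss module (the loss presumably carries learnable state of its own, e.g. a contrastive temperature). Below is a minimal sketch of that pattern, not the repository's code: ToyModel and ToyLoss are hypothetical stand-ins, and the two parameter iterators are combined with itertools.chain, since nn.Module.parameters() returns an iterator.

import itertools

import torch
from torch import nn


class ToyModel(nn.Module):  # hypothetical stand-in for the FLAVA pretraining model
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)

    def forward(self, x):
        return self.encoder(x)


class ToyLoss(nn.Module):  # hypothetical stand-in for FLAVAPretrainingLoss
    def __init__(self):
        super().__init__()
        # loss modules can carry learnable state, e.g. a temperature/logit scale
        self.logit_scale = nn.Parameter(torch.ones([]))

    def forward(self, prediction, target):
        return (self.logit_scale * (prediction - target) ** 2).mean()


model, loss_fn = ToyModel(), ToyLoss()

# Both parameter sets go to the optimizer, mirroring the call into
# get_optimizers_for_lightning above; itertools.chain merges the iterators.
parameters = itertools.chain(model.parameters(), loss_fn.parameters())
optimizer = torch.optim.AdamW(parameters, lr=1e-4, weight_decay=0.1)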

torchmultimodal/models/flava/model.py

Lines changed: 34 additions & 10 deletions
@@ -105,6 +105,24 @@ class FLAVAForClassificationOutput(ModelOutput):
     loss: Tensor
 
 
+@dataclass
+class FLAVAForPretrainingOutput:
+    multimodal_masked_sequence: Tensor
+    pos_mask: Tensor
+    mim_labels: Tensor
+    mlm_labels: Tensor
+    mmm_mlm_labels: Tensor
+    mmm_mim_labels: Tensor
+    itm_labels: Tensor
+    projected_image_embeddings: Tensor
+    projected_text_embeddings: Tensor
+    itm_logits: Tensor
+    mlm_head_output: Tensor
+    mim_head_output: Tensor
+    mmm_mlm_head_output: Tensor
+    mmm_mim_head_output: Tensor
+
+
 class FLAVAModel(nn.Module, PretrainedMixin):
     def __init__(
         self,
@@ -452,6 +470,8 @@ def forward(
                 text_masked_sequence[:, start_index:, :], mlm_labels
             )
 
+        mmm_mlm_labels = mlm_labels
+        mmm_mim_labels = image_labels
         if multimodal_masked_sequence is not None:
             if itm_labels is not None:
                 pos_pairs = itm_labels.ne(0)
@@ -466,37 +486,41 @@
             itm_logits = self.itm_head(multimodal_masked_sequence)
 
             multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
+
             if mlm_labels is not None:
-                mlm_labels = mlm_labels[pos_mask]
+                mmm_mlm_labels = mlm_labels[pos_mask]
+
             if image_labels is not None:
-                image_labels = image_labels[pos_mask]
+                mmm_mim_labels = image_labels[pos_mask]
 
         if multimodal_masked_sequence is not None:
             start_index = (
-                -mlm_labels.size(1)
-                if mlm_labels is not None
+                -mmm_mlm_labels.size(1)
+                if mmm_mlm_labels is not None
                 else -(text_masked_sequence.size(1) - 1)
             )
             sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
-            mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mlm_labels)
+            mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mmm_mlm_labels)
 
         if multimodal_masked_sequence is not None:
             # Starts from 2 because of 2 CLS, one for multimodal encoder and one
             # that comes from image encoder.
             total_indices = (
-                image_labels.size(1)
-                if image_labels is not None
+                mmm_mim_labels.size(1)
+                if mmm_mim_labels is not None
                 else (image_masked_sequence.size(1) - 1)
             )
             sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
-            mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, image_labels)
+            mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)
 
-        return self.loss(
+        return FLAVAForPretrainingOutput(
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
             pos_mask=pos_mask,
-            itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            mmm_mim_labels=mmm_mim_labels,
+            mmm_mlm_labels=mmm_mlm_labels,
+            itm_labels=itm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
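
The pretraining model no longer computes its loss internally: forward now packages the masked sequences, labels, head outputs, and logits into the new FLAVAForPretrainingOutput dataclass and leaves the loss to the caller. A minimal sketch of that "return raw outputs, compute the loss outside" shape, using hypothetical toy classes rather than the FLAVA ones:

from dataclasses import dataclass

import torch
from torch import Tensor, nn


@dataclass
class ToyPretrainOutput:  # hypothetical analogue of FLAVAForPretrainingOutput
    logits: Tensor
    labels: Tensor


class ToyPretrainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(4, 3)

    def forward(self, x: Tensor, labels: Tensor) -> ToyPretrainOutput:
        # Before the refactor the loss was computed here; after it, the model
        # only packages what the loss needs.
        return ToyPretrainOutput(logits=self.head(x), labels=labels)


class ToyPretrainLoss(nn.Module):  # hypothetical analogue of FLAVAPretrainingLoss
    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        return nn.functional.cross_entropy(logits, labels)


model, loss_fn = ToyPretrainModel(), ToyPretrainLoss()
output = model(torch.randn(2, 4), torch.tensor([0, 2]))
loss = loss_fn(logits=output.logits, labels=output.labels)  # loss lives outside the model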

torchmultimodal/modules/losses/flava.py

Lines changed: 4 additions & 2 deletions
@@ -273,6 +273,8 @@ def forward(
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        mmm_mim_labels: Optional[Tensor] = None,
+        mmm_mlm_labels: Optional[Tensor] = None,
         projected_image_embeddings: Optional[Tensor] = None,
         projected_text_embeddings: Optional[Tensor] = None,
         itm_logits: Optional[Tensor] = None,
@@ -315,14 +317,14 @@
 
         if mmm_mlm_head_output is not None and self.mmm_text_loss_weight > 0:
             outputs.mmm_text_output = self.mmm_loss.mlm(
-                mmm_mlm_head_output, mlm_labels
+                mmm_mlm_head_output, mmm_mlm_labels
             )  # type: ignore
             outputs.mmm_text_output.loss *= self.mmm_text_loss_weight
             outputs.losses.mmm_text_loss = outputs.mmm_text_output.loss
 
         if mmm_mim_head_output is not None and self.mmm_image_loss_weight > 0:
             outputs.mmm_image_output = self.mmm_loss.mim(
-                mmm_mim_head_output, mim_labels
+                mmm_mim_head_output, mmm_mim_labels
             )  # type: ignore
             outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss
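
The loss's forward now accepts mmm_mim_labels and mmm_mlm_labels separately from mim_labels and mlm_labels because, in the model, the MMM labels are the ITM-positive subset of the originals (indexed with pos_mask), while the unimodal MIM/MLM labels stay unfiltered. A small illustrative sketch of that filtering, with made-up shapes and values:

import torch

mlm_labels = torch.tensor([[5, -1, 7], [2, 3, -1], [-1, 4, 9]])  # per-sample MLM targets
itm_labels = torch.tensor([1, 0, 1])  # 1 = matching image/text pair, 0 = mismatched
pos_mask = itm_labels.ne(0)           # keep only the positive pairs

mmm_mlm_labels = mlm_labels[pos_mask]  # what the MMM text loss sees
print(mlm_labels.shape, mmm_mlm_labels.shape)  # torch.Size([3, 3]) torch.Size([2, 3])

# mlm_labels itself is still passed through unfiltered for the unimodal MLM
# loss, which is why forward() now takes both sets of labels.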
