 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Tuple
+from typing import Any, List, Tuple

 import torch
 from pytorch_lightning import LightningModule
+from torch import Tensor
 from torchmetrics import Accuracy
 from torchmultimodal.models.flava.model import (
     flava_model_for_classification,
     flava_model_for_pretraining,
 )
+from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
 from transformers.optimization import get_cosine_schedule_with_warmup


 def get_optimizers_for_lightning(
-    model: torch.nn.Module,
+    parameters: List[Tensor],
     learning_rate: float,
     adam_eps: float,
     adam_weight_decay: float,
@@ -26,7 +28,7 @@ def get_optimizers_for_lightning(
     max_steps: int,
 ):
     optimizer = torch.optim.AdamW(
-        model.parameters(),
+        parameters,
         lr=learning_rate,
         betas=adam_betas,
         eps=adam_eps,
@@ -59,6 +61,7 @@ def __init__(
         self.adam_weight_decay = adam_weight_decay
         self.warmup_steps = warmup_steps
         self.max_steps = max_steps
+        self.loss = FLAVAPretrainingLoss()

     def training_step(self, batch, batch_idx):
         output = self._step(batch, batch_idx)
@@ -104,11 +107,29 @@ def _step(self, batch, batch_idx):
             itm_labels=batch.get("itm_labels", None),
             required_embedding=required_embedding,
         )
-        return output
+
+        loss = self.loss(
+            multimodal_masked_sequence=output.multimodal_masked_sequence,
+            pos_mask=output.pos_mask,
+            itm_labels=output.itm_labels,
+            mim_labels=output.mim_labels,
+            mlm_labels=output.mlm_labels,
+            mmm_mlm_labels=output.mmm_mlm_labels,
+            mmm_mim_labels=output.mmm_mim_labels,
+            projected_image_embeddings=output.projected_image_embeddings,
+            projected_text_embeddings=output.projected_text_embeddings,
+            itm_logits=output.itm_logits,
+            mlm_head_output=output.mlm_head_output,
+            mim_head_output=output.mim_head_output,
+            mmm_mlm_head_output=output.mmm_mlm_head_output,
+            mmm_mim_head_output=output.mmm_mim_head_output,
+        )
+        return loss

     def configure_optimizers(self):
+        parameters = list(self.model.parameters()) + list(self.loss.parameters())
         return get_optimizers_for_lightning(
-            self.model,
+            parameters,
             self.learning_rate,
             self.adam_eps,
             self.adam_weight_decay,
@@ -194,7 +215,7 @@ def _step(self, batch, batch_idx):

     def configure_optimizers(self):
         return get_optimizers_for_lightning(
-            self.model,
+            self.model.parameters(),
             self.learning_rate,
             self.adam_eps,
             self.adam_weight_decay,
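The two changes above interact: FLAVAPretrainingLoss is itself an nn.Module, so any learnable parameters it carries have to reach the optimizer alongside the model's. That is why configure_optimizers now builds one flat parameter collection and get_optimizers_for_lightning accepts parameters rather than a module. Below is a minimal sketch of that pattern; the helper name build_optimizers, the placeholder hyperparameters, and the scheduler return format are illustrative assumptions, not the exact elided body of the function in this commit.

# Sketch: why the refactor passes a flat parameter collection.
# torch.nn.Module.parameters() returns a generator, so two of them cannot be
# joined with "+" directly; convert to lists (or use itertools.chain) first.
from typing import Iterable, Tuple

import torch
from torch import Tensor, nn
from transformers.optimization import get_cosine_schedule_with_warmup


def build_optimizers(
    parameters: Iterable[Tensor],
    learning_rate: float = 2e-4,      # placeholder hyperparameters,
    adam_eps: float = 1e-8,           # not the values used in the example script
    adam_weight_decay: float = 0.01,
    adam_betas: Tuple[float, float] = (0.9, 0.999),
    warmup_steps: int = 2000,
    max_steps: int = 100000,
):
    # AdamW only needs an iterable of parameters, so a merged list works.
    optimizer = torch.optim.AdamW(
        parameters,
        lr=learning_rate,
        betas=adam_betas,
        eps=adam_eps,
        weight_decay=adam_weight_decay,
    )
    # Cosine schedule with warmup, stepped per batch so it tracks max_steps;
    # the return format follows the common Lightning convention.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps,
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


# Stand-ins for self.model and self.loss from the diff.
model, loss = nn.Linear(4, 4), nn.Linear(4, 1)
params = list(model.parameters()) + list(loss.parameters())
optimizers, schedulers = build_optimizers(params)
# itertools.chain(model.parameters(), loss.parameters()) would also work,
# since the optimizer only iterates over the parameters once.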