24 changes: 22 additions & 2 deletions examples/flava/model.py
@@ -8,16 +8,18 @@

import torch
from pytorch_lightning import LightningModule
from torch import nn
from torchmetrics import Accuracy
from torchmultimodal.models.flava.model import (
flava_model_for_classification,
flava_model_for_pretraining,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss
from transformers.optimization import get_cosine_schedule_with_warmup


def get_optimizers_for_lightning(
model: torch.nn.Module,
model: nn.Module,
learning_rate: float,
adam_eps: float,
adam_weight_decay: float,
@@ -59,6 +61,7 @@ def __init__(
self.adam_weight_decay = adam_weight_decay
self.warmup_steps = warmup_steps
self.max_steps = max_steps
self.loss = FLAVAPretrainingLoss(logit_scale=self.model.logit_scale)

def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
@@ -104,7 +107,24 @@ def _step(self, batch, batch_idx):
itm_labels=batch.get("itm_labels", None),
required_embedding=required_embedding,
)
return output

loss = self.loss(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
return loss

def configure_optimizers(self):
return get_optimizers_for_lightning(
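With this change, _step forwards fourteen fields from the model output into FLAVAPretrainingLoss by hand, and the same keyword list is repeated in each test below. Since FLAVAForPretrainingOutput is a plain dataclass whose field names match the loss's keyword arguments, that forwarding could be collapsed into a small helper. A minimal sketch, not part of this diff; the helper name pretraining_loss_kwargs is invented for illustration:

from dataclasses import fields

def pretraining_loss_kwargs(output):
    # Shallow field -> value mapping over FLAVAForPretrainingOutput.
    # dataclasses.asdict is avoided on purpose: it deep-copies values, which
    # can fail for non-leaf tensors still attached to the autograd graph.
    return {f.name: getattr(output, f.name) for f in fields(output)}

# Usage inside _step, assuming the field names stay in sync with the loss:
# loss = self.loss(**pretraining_loss_kwargs(output))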
52 changes: 51 additions & 1 deletion test/models/flava/test_checkpoint.py
@@ -12,6 +12,7 @@
flava_model_for_classification,
flava_model_for_pretraining,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss


@pytest.fixture(autouse=True)
@@ -139,8 +140,25 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
image_input = inputs_pretraining("image")
text_input = inputs_pretraining("text")
flava = pretraining_model()

losses = FLAVAPretrainingLoss(flava.logit_scale)
output = flava(*mm_input)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)

actual = output.losses
expected = dict(
mmm_text_loss=10.9567,
@@ -153,6 +171,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
self._assert_tensor_dicts_equal(actual, expected)

output = flava(*image_input)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
actual = output.losses
expected = dict(
mmm_text_loss=None,
@@ -165,6 +199,22 @@ def test_flava_model_for_pretraining(self, inputs_pretraining, pretraining_model
self._assert_tensor_dicts_equal(actual, expected)

output = flava(*text_input)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
actual = output.losses
expected = dict(
mmm_text_loss=None,
50 changes: 50 additions & 0 deletions test/models/flava/test_flava.py
@@ -18,6 +18,7 @@
FLAVAOutput,
)
from torchmultimodal.modules.layers.transformer import TransformerOutput
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLoss

NUM_CLASSES = 2

@@ -61,6 +62,7 @@ def test_forward_pretraining(self):
mlm_labels[:, 1:3] = text[:, 1:3]
itm_labels = torch.tensor((0, 1), dtype=torch.long)
flava = flava_model_for_pretraining()
losses = FLAVAPretrainingLoss(flava.logit_scale)
flava.eval()
output = flava(
image=image,
@@ -72,6 +74,22 @@
itm_labels=itm_labels,
mlm_labels=mlm_labels,
)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
self.assertIsNone(output.mlm_output)
self.assertIsNone(output.mim_output)
self.assertIsNotNone(output.global_contrastive_output)
@@ -96,6 +114,22 @@ def test_forward_pretraining(self):
itm_labels=itm_labels,
mlm_labels=mlm_labels,
)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
self.assertIsNone(output.mlm_output)
self.assertIsNotNone(output.mim_output)
self.assertIsNone(output.global_contrastive_output)
@@ -120,6 +154,22 @@ def test_forward_pretraining(self):
itm_labels=itm_labels,
mlm_labels=mlm_labels,
)
output = losses(
multimodal_masked_sequence=output.multimodal_masked_sequence,
pos_mask=output.pos_mask,
itm_labels=output.itm_labels,
mim_labels=output.mim_labels,
mlm_labels=output.mlm_labels,
mmm_mlm_labels=output.mmm_mlm_labels,
mmm_mim_labels=output.mmm_mim_labels,
projected_image_embeddings=output.projected_image_embeddings,
projected_text_embeddings=output.projected_text_embeddings,
itm_logits=output.itm_logits,
mlm_head_output=output.mlm_head_output,
mim_head_output=output.mim_head_output,
mmm_mlm_head_output=output.mmm_mlm_head_output,
mmm_mim_head_output=output.mmm_mim_head_output,
)
self.assertIsNotNone(output.mlm_output)
self.assertIsNone(output.mim_output)
self.assertIsNone(output.global_contrastive_output)
43 changes: 30 additions & 13 deletions torchmultimodal/models/flava/model.py
@@ -24,11 +24,7 @@
TransformerEncoder,
TransformerOutput,
)
from torchmultimodal.modules.losses.flava import (
FLAVAPretrainingLoss,
FLAVAPretrainingLossOutput,
Pooler,
)
from torchmultimodal.modules.losses.flava import FLAVAPretrainingLossOutput, Pooler
from torchmultimodal.utils.common import ModelOutput, PretrainedMixin
from typing_extensions import Literal

@@ -62,7 +58,7 @@
FLAVA_FOR_PRETRAINED_MAPPING = {
# This will no longer load with the updated model, but keeping here just in case
# "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
"flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_no_loss.pt",
}

FLAVA_MODEL_MAPPING = {
@@ -105,6 +101,24 @@ class FLAVAForClassificationOutput(ModelOutput):
loss: Tensor


@dataclass
class FLAVAForPretrainingOutput:
multimodal_masked_sequence: Tensor
pos_mask: Tensor
mim_labels: Tensor
mlm_labels: Tensor
mmm_mlm_labels: Tensor
mmm_mim_labels: Tensor
itm_labels: Tensor
projected_image_embeddings: Tensor
projected_text_embeddings: Tensor
itm_logits: Tensor
mlm_head_output: Tensor
mim_head_output: Tensor
mmm_mlm_head_output: Tensor
mmm_mim_head_output: Tensor
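One observation on the annotations above, not part of the diff: every field is typed as a plain Tensor, but the image-only and text-only paths exercised in the tests leave several of them unset (the mmm labels and head outputs, for example, when one modality is absent). If the annotations are meant to reflect that, an Optional-typed variant along these lines would describe the runtime values; purely an illustrative sketch:

from dataclasses import dataclass
from typing import Optional

from torch import Tensor


@dataclass
class FLAVAForPretrainingOutputSketch:
    # Illustrative Optional-typed variant (name suffixed to avoid clashing with
    # the real class); the remaining eleven fields would follow the same pattern.
    multimodal_masked_sequence: Optional[Tensor] = None
    pos_mask: Optional[Tensor] = None
    mmm_mlm_labels: Optional[Tensor] = None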


class FLAVAModel(nn.Module, PretrainedMixin):
def __init__(
self,
@@ -367,22 +381,22 @@ def __init__(
self,
model: FLAVAModel,
image_codebook: nn.Module,
loss: nn.Module,
itm_head: nn.Module,
mlm_head: nn.Module,
mim_head: nn.Module,
mmm_mlm_head: nn.Module,
mmm_mim_head: nn.Module,
logit_scale: nn.Module,
):
super().__init__()
self.model = model
self.image_codebook = image_codebook
self.loss = loss
self.itm_head = itm_head
self.mlm_head = mlm_head
self.mim_head = mim_head
self.mmm_mlm_head = mmm_mlm_head
self.mmm_mim_head = mmm_mim_head
self.logit_scale = logit_scale

def encode_image(
self,
@@ -469,8 +483,10 @@ def forward(
itm_logits = self.itm_head(multimodal_masked_sequence)

multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]

if mlm_labels is not None:
mmm_mlm_labels = mlm_labels[pos_mask]

if image_labels is not None:
mmm_mim_labels = image_labels[pos_mask]

@@ -494,14 +510,14 @@ def forward(
sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, mmm_mim_labels)

return self.loss(
return FLAVAForPretrainingOutput(
multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
pos_mask=pos_mask,
itm_labels=itm_labels,
mim_labels=image_labels,
mlm_labels=mlm_labels,
mmm_mlm_labels=mmm_mlm_labels,
mmm_mim_labels=mmm_mim_labels,
mmm_mlm_labels=mmm_mlm_labels,
itm_labels=itm_labels,
projected_image_embeddings=flava_output.projected_image_embeddings,
projected_text_embeddings=flava_output.projected_text_embeddings,
itm_logits=itm_logits,
@@ -660,6 +676,7 @@ def flava_model_for_pretraining(
codebook_image_size: int = 112,
pretrained_model_key: Optional[str] = None,
image_vocab_size: int = 8192,
logit_scale: float = math.log(1 / 0.07),
**flava_model_kwargs: Any,
# TODO: Add parameters for loss here
) -> FLAVAForPreTraining:
@@ -677,18 +694,18 @@ def flava_model_for_pretraining(
mmm_mim_head = MaskedPredictionHead(
hidden_size=hidden_size, vocab_size=image_vocab_size
)
losses = FLAVAPretrainingLoss()

codebook = DalleVAEEncoder(image_size=codebook_image_size)

flava = FLAVAForPreTraining(
model=model,
image_codebook=codebook,
loss=losses,
itm_head=itm_head,
mlm_head=mlm_head,
mim_head=mim_head,
mmm_mlm_head=mmm_mlm_head,
mmm_mim_head=mmm_mim_head,
logit_scale=nn.Parameter(logit_scale * torch.ones([])),
)

if pretrained_model_key is not None:
Expand Down