
Commit ff9b47c

[FLAVA] Move masked prediction head to flava_for_pretraining
ghstack-source-id: 081f266
Pull Request resolved: #195
1 parent: 6c4880a

2 files changed: 154 additions, 148 deletions

torchmultimodal/models/flava/model.py (134 additions, 9 deletions)
@@ -62,7 +62,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
 }

 FLAVA_MODEL_MAPPING = {
@@ -314,6 +314,50 @@ def forward(self, hidden_states: Tensor):
         return logits


+class MaskedPredictionHead(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        vocab_size: int = 30522,
+        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
+        layer_norm_eps: float = 1e-5,
+        use_fp32_layer_norm: bool = True,
+        ignore_index: int = -1,
+        **kwargs: Any,
+    ):
+        super().__init__()
+
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.transform_act_fn = transform_act_fn
+
+        self.layer_norm: nn.LayerNorm
+        if use_fp32_layer_norm:
+            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
+        else:
+            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+        # Need a link between the two variables so that the bias is
+        # correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+        self.ignore_index = ignore_index
+
+    def forward(self, hidden_states: Tensor, masked_labels: Tensor) -> Tensor:
+        masked_tokens = masked_labels.ne(self.ignore_index)
+        sequence_output = hidden_states[masked_tokens, :]
+
+        head_output = self.dense(sequence_output)
+        head_output = self.transform_act_fn(head_output)
+        head_output = self.layer_norm(head_output)
+        head_output = self.decoder(head_output)
+        return head_output
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
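
Note: the new head computes logits only for the masked positions, so its output has one row per masked token rather than one per sequence position, and a cross-entropy loss against the gathered labels lines up one-to-one. A minimal sketch of the call convention, assuming this commit's torchmultimodal is importable (shapes and label values are illustrative):

```python
import torch
from torchmultimodal.models.flava.model import MaskedPredictionHead

head = MaskedPredictionHead(hidden_size=768, vocab_size=30522, ignore_index=-1)

batch_size, seq_len = 2, 8
hidden_states = torch.randn(batch_size, seq_len, 768)

# ignore_index (-1) marks unmasked positions; anything else is a target token id.
masked_labels = torch.full((batch_size, seq_len), -1, dtype=torch.long)
masked_labels[0, 3] = 42
masked_labels[1, 5] = 7

logits = head(hidden_states, masked_labels)
print(logits.shape)  # torch.Size([2, 30522]): one row per masked position

# Loss pairing (sketch): gather the same masked positions from the labels.
loss = torch.nn.functional.cross_entropy(logits, masked_labels[masked_labels.ne(-1)])
```
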
@@ -325,12 +369,20 @@ def __init__(
         image_codebook: nn.Module,
         loss: nn.Module,
         itm_head: nn.Module,
+        mlm_head: nn.Module,
+        mim_head: nn.Module,
+        mmm_mlm_head: nn.Module,
+        mmm_mim_head: nn.Module,
     ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
         self.itm_head = itm_head
+        self.mlm_head = mlm_head
+        self.mim_head = mim_head
+        self.mmm_mlm_head = mmm_mlm_head
+        self.mmm_mim_head = mmm_mim_head

     def encode_image(
         self,
@@ -380,24 +432,78 @@ def forward(
         )
         multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
         itm_logits = None
+
+        image_masked_sequence = flava_output.image_masked.last_hidden_state
+        text_masked_sequence = flava_output.text_masked.last_hidden_state
+        mlm_head_output = (
+            mim_head_output
+        ) = mmm_mlm_head_output = mmm_mim_head_output = None
+        pos_mask = None
+        if image_masked_sequence is not None and multimodal_masked_sequence is None:
+            # Remove CLS token from image_masked_sequence
+            start_index = -image_labels.size(1) if image_labels is not None else 1
+            mim_head_output = self.mim_head(
+                image_masked_sequence[:, start_index:, :], image_labels
+            )
+
+        if text_masked_sequence is not None and multimodal_masked_sequence is None:
+            start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
+            mlm_head_output = self.mlm_head(
+                text_masked_sequence[:, start_index:, :], mlm_labels
+            )
+
         if multimodal_masked_sequence is not None:
+            if itm_labels is not None:
+                pos_pairs = itm_labels.ne(0)
+                pos_mask = torch.where(
+                    pos_pairs.any(), pos_pairs, pos_pairs.new([True])
+                )
+            else:
+                pos_mask = torch.ones(
+                    multimodal_masked_sequence.size(0),
+                    device=multimodal_masked_sequence.device,
+                ).bool()
             itm_logits = self.itm_head(multimodal_masked_sequence)

+            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
+            if mlm_labels is not None:
+                mlm_labels = mlm_labels[pos_mask]
+            if image_labels is not None:
+                image_labels = image_labels[pos_mask]
+
+        if multimodal_masked_sequence is not None:
+            start_index = (
+                -mlm_labels.size(1)
+                if mlm_labels is not None
+                else -(text_masked_sequence.size(1) - 1)
+            )
+            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
+            mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mlm_labels)
+
+        if multimodal_masked_sequence is not None:
+            # Starts from 2 because of 2 CLS, one for multimodal encoder and one
+            # that comes from image encoder.
+            total_indices = (
+                image_labels.size(1)
+                if image_labels is not None
+                else (image_masked_sequence.size(1) - 1)
+            )
+            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
+            mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, image_labels)
+
         return self.loss(
-            image_sequence=flava_output.image.last_hidden_state,
-            text_sequence=flava_output.text.last_hidden_state,
-            image_masked_sequence=flava_output.image_masked.last_hidden_state,
-            text_masked_sequence=flava_output.text_masked.last_hidden_state,
-            multimodal_sequence=flava_output.multimodal.last_hidden_state
-            if not skip_unmasked_mm_encoder
-            else None,
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
+            pos_mask=pos_mask,
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
+            mlm_head_output=mlm_head_output,
+            mim_head_output=mim_head_output,
+            mmm_mlm_head_output=mmm_mlm_head_output,
+            mmm_mim_head_output=mmm_mim_head_output,
         )
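Note: the `torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))` line above guards against a batch with no positive ITM pairs: when every pair is negative, the one-element `True` tensor broadcasts to an all-`True` mask, so the `pos_mask` selection never hands the MMM heads an empty batch. The slicing offsets that follow reflect the multimodal sequence layout described by the in-diff comment (multimodal CLS, then image CLS plus patch tokens, then text tokens): image logits come from positions `2 : 2 + total_indices`, text logits from the tail. A standalone sketch of just the masking logic, with illustrative labels:

```python
import torch

def itm_pos_mask(itm_labels: torch.Tensor) -> torch.Tensor:
    """Mirror of the positive-pair masking in FLAVAForPreTraining.forward."""
    pos_pairs = itm_labels.ne(0)  # 0 marks a mismatched image/text pair
    # Keep the per-sample mask if any pair is positive; otherwise broadcast a
    # single True so downstream selections are never empty.
    return torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))

print(itm_pos_mask(torch.tensor([1, 0, 1])))  # tensor([ True, False,  True])
print(itm_pos_mask(torch.tensor([0, 0, 0])))  # tensor([True, True, True])
```
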
@@ -548,17 +654,36 @@ def flava_model(
 def flava_model_for_pretraining(
     codebook_image_size: int = 112,
     pretrained_model_key: Optional[str] = None,
+    image_vocab_size: int = 8192,
     **flava_model_kwargs: Any,
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
     hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
     itm_head = ITMHead(hidden_size)
+    mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
+    mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
+    mmm_mlm_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=text_vocab_size
+    )
+    mmm_mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
-        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
+        model=model,
+        image_codebook=codebook,
+        loss=losses,
+        itm_head=itm_head,
+        mlm_head=mlm_head,
+        mim_head=mim_head,
+        mmm_mlm_head=mmm_mlm_head,
+        mmm_mim_head=mmm_mim_head,
     )

     if pretrained_model_key is not None:
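
Note: a hedged construction sketch using the factory as updated in this diff. Defaults are taken from the signatures above; be aware that building the codebook instantiates `DalleVAEEncoder`, which may fetch pretrained encoder weights, so treat this as illustrative rather than a quick smoke test:

```python
from torchmultimodal.models.flava.model import flava_model_for_pretraining

# Defaults from the diff: hidden_size 768, text vocab 30522, image_vocab_size 8192.
flava = flava_model_for_pretraining()

# Four separate MaskedPredictionHead instances (no weight sharing), differing
# only in output vocabulary size:
print(flava.mlm_head.decoder.out_features)      # 30522 (text vocab)
print(flava.mim_head.decoder.out_features)      # 8192  (image codebook vocab)
print(flava.mmm_mlm_head.decoder.out_features)  # 30522
print(flava.mmm_mim_head.decoder.out_features)  # 8192
```

Passing `pretrained_model_key="flava_full"` resolves through `FLAVA_FOR_PRETRAINED_MAPPING` to the new `..._unified_itm_mp.pt` checkpoint; the `_mp` suffix presumably reflects the masked prediction heads relocated by this commit.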
