 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm.pt",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_unified_itm_mp.pt",
 }

 FLAVA_MODEL_MAPPING = {
@@ -314,6 +314,50 @@ def forward(self, hidden_states: Tensor):
         return logits


+class MaskedPredictionHead(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int = 768,
+        vocab_size: int = 30522,
+        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
+        layer_norm_eps: float = 1e-5,
+        use_fp32_layer_norm: bool = True,
+        ignore_index: int = -1,
+        **kwargs: Any,
+    ):
+        super().__init__()
+
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.transform_act_fn = transform_act_fn
+
+        self.layer_norm: nn.LayerNorm
+        if use_fp32_layer_norm:
+            self.layer_norm = Fp32LayerNorm(hidden_size, eps=layer_norm_eps)
+        else:
+            self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+        # Need a link between the two variables so that the bias is
+        # correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+        self.ignore_index = ignore_index
+
+    def forward(self, hidden_states: Tensor, masked_labels: Tensor) -> Tensor:
+        masked_tokens = masked_labels.ne(self.ignore_index)
+        sequence_output = hidden_states[masked_tokens, :]
+
+        head_output = self.dense(sequence_output)
+        head_output = self.transform_act_fn(head_output)
+        head_output = self.layer_norm(head_output)
+        head_output = self.decoder(head_output)
+        return head_output
+
+
 class FLAVAForPreTraining(nn.Module, PretrainedMixin):
     # TODOs:
     # 1. Expose logit scale
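The new `MaskedPredictionHead` scores only the positions whose label differs from `ignore_index`, so its output has one row per masked token rather than one per input token. A minimal standalone sketch of that behaviour, assuming this module's imports are available and using made-up shapes and token ids:

import torch

head = MaskedPredictionHead(hidden_size=768, vocab_size=30522)
hidden_states = torch.randn(2, 5, 768)                  # batch of 2 sequences, 5 tokens each
masked_labels = torch.full((2, 5), -1, dtype=torch.long)  # -1 is the default ignore_index
masked_labels[0, 2] = 1037                              # one masked position per sample
masked_labels[1, 4] = 2003
logits = head(hidden_states, masked_labels)
print(logits.shape)                                     # torch.Size([2, 30522])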
@@ -325,12 +369,20 @@ def __init__(
         image_codebook: nn.Module,
         loss: nn.Module,
         itm_head: nn.Module,
+        mlm_head: nn.Module,
+        mim_head: nn.Module,
+        mmm_mlm_head: nn.Module,
+        mmm_mim_head: nn.Module,
     ):
         super().__init__()
         self.model = model
         self.image_codebook = image_codebook
         self.loss = loss
         self.itm_head = itm_head
+        self.mlm_head = mlm_head
+        self.mim_head = mim_head
+        self.mmm_mlm_head = mmm_mlm_head
+        self.mmm_mim_head = mmm_mim_head

     def encode_image(
         self,
@@ -380,24 +432,78 @@ def forward(
         )
         multimodal_masked_sequence = flava_output.multimodal_masked.last_hidden_state
         itm_logits = None
+
+        image_masked_sequence = flava_output.image_masked.last_hidden_state
+        text_masked_sequence = flava_output.text_masked.last_hidden_state
+        mlm_head_output = (
+            mim_head_output
+        ) = mmm_mlm_head_output = mmm_mim_head_output = None
+        pos_mask = None
+        if image_masked_sequence is not None and multimodal_masked_sequence is None:
+            # Remove CLS token from image_masked_sequence
+            start_index = -image_labels.size(1) if image_labels is not None else 1
+            mim_head_output = self.mim_head(
+                image_masked_sequence[:, start_index:, :], image_labels
+            )
+
+        if text_masked_sequence is not None and multimodal_masked_sequence is None:
+            start_index = -mlm_labels.size(1) if mlm_labels is not None else 1
+            mlm_head_output = self.mlm_head(
+                text_masked_sequence[:, start_index:, :], mlm_labels
+            )
+
         if multimodal_masked_sequence is not None:
+            if itm_labels is not None:
+                pos_pairs = itm_labels.ne(0)
+                pos_mask = torch.where(
+                    pos_pairs.any(), pos_pairs, pos_pairs.new([True])
+                )
+            else:
+                pos_mask = torch.ones(
+                    multimodal_masked_sequence.size(0),
+                    device=multimodal_masked_sequence.device,
+                ).bool()
             itm_logits = self.itm_head(multimodal_masked_sequence)

+            multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]
+            if mlm_labels is not None:
+                mlm_labels = mlm_labels[pos_mask]
+            if image_labels is not None:
+                image_labels = image_labels[pos_mask]
+
+        if multimodal_masked_sequence is not None:
+            start_index = (
+                -mlm_labels.size(1)
+                if mlm_labels is not None
+                else -(text_masked_sequence.size(1) - 1)
+            )
+            sequence_for_text = multimodal_masked_sequence[:, start_index:, :]
+            mmm_mlm_head_output = self.mmm_mlm_head(sequence_for_text, mlm_labels)
+
+        if multimodal_masked_sequence is not None:
+            # Starts from 2 because of 2 CLS, one for multimodal encoder and one
+            # that comes from image encoder.
+            total_indices = (
+                image_labels.size(1)
+                if image_labels is not None
+                else (image_masked_sequence.size(1) - 1)
+            )
+            sequence_for_image = multimodal_masked_sequence[:, 2 : 2 + total_indices, :]
+            mmm_mim_head_output = self.mmm_mim_head(sequence_for_image, image_labels)
+
         return self.loss(
-            image_sequence=flava_output.image.last_hidden_state,
-            text_sequence=flava_output.text.last_hidden_state,
-            image_masked_sequence=flava_output.image_masked.last_hidden_state,
-            text_masked_sequence=flava_output.text_masked.last_hidden_state,
-            multimodal_sequence=flava_output.multimodal.last_hidden_state
-            if not skip_unmasked_mm_encoder
-            else None,
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
+            pos_mask=pos_mask,
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
             projected_image_embeddings=flava_output.projected_image_embeddings,
             projected_text_embeddings=flava_output.projected_text_embeddings,
             itm_logits=itm_logits,
+            mlm_head_output=mlm_head_output,
+            mim_head_output=mim_head_output,
+            mmm_mlm_head_output=mmm_mlm_head_output,
+            mmm_mim_head_output=mmm_mim_head_output,
         )


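In the multimodal branch above, the ITM labels double as a positive-pair mask: only matching image/text pairs are carried forward to the multimodal MLM/MIM heads, and the mask falls back to all-True when a batch contains no positive pairs so those heads never receive an empty tensor. A hedged standalone sketch of just that masking step, with illustrative tensors rather than values from this PR:

import torch

itm_labels = torch.tensor([1, 0, 1])                   # 1 = matched pair, 0 = mismatched
multimodal_masked_sequence = torch.randn(3, 10, 768)
mlm_labels = torch.full((3, 7), -1, dtype=torch.long)

pos_pairs = itm_labels.ne(0)                           # tensor([True, False, True])
pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))

multimodal_masked_sequence = multimodal_masked_sequence[pos_mask]  # (2, 10, 768)
mlm_labels = mlm_labels[pos_mask]                                  # (2, 7)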
@@ -548,17 +654,36 @@ def flava_model(
 def flava_model_for_pretraining(
     codebook_image_size: int = 112,
     pretrained_model_key: Optional[str] = None,
+    image_vocab_size: int = 8192,
     **flava_model_kwargs: Any,
     # TODO: Add parameters for loss here
 ) -> FLAVAForPreTraining:
     model = flava_model(**flava_model_kwargs)
     hidden_size = flava_model_kwargs.get("hidden_size") or 768
+    text_vocab_size = flava_model_kwargs.get("vocab_size") or 30522
     itm_head = ITMHead(hidden_size)
+    mlm_head = MaskedPredictionHead(hidden_size=hidden_size, vocab_size=text_vocab_size)
+    mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
+    mmm_mlm_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=text_vocab_size
+    )
+    mmm_mim_head = MaskedPredictionHead(
+        hidden_size=hidden_size, vocab_size=image_vocab_size
+    )
     losses = FLAVAPretrainingLoss()
     codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
-        model=model, image_codebook=codebook, loss=losses, itm_head=itm_head
+        model=model,
+        image_codebook=codebook,
+        loss=losses,
+        itm_head=itm_head,
+        mlm_head=mlm_head,
+        mim_head=mim_head,
+        mmm_mlm_head=mmm_mlm_head,
+        mmm_mim_head=mmm_mim_head,
     )

     if pretrained_model_key is not None:
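For reference, a hedged sketch of how the updated builder is expected to be called; the keyword values are simply the defaults from the signature above, and "flava_full" is the key from FLAVA_FOR_PRETRAINED_MAPPING at the top of this diff:

flava = flava_model_for_pretraining(
    codebook_image_size=112,
    image_vocab_size=8192,
    pretrained_model_key=None,   # or "flava_full" to load the released checkpoint
)
assert isinstance(flava.mlm_head, MaskedPredictionHead)
assert isinstance(flava.mmm_mim_head, MaskedPredictionHead)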