diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index d514dbb20..f6b726120 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -13,32 +13,7 @@ MultimodalTransformer, ) from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower - -try: - from transformers import ( - BeamSearchScorer, - LogitsProcessorList, - TopPLogitsWarper, - TopKLogitsWarper, - RepetitionPenaltyLogitsProcessor, - MinLengthLogitsProcessor, - MaxLengthCriteria, - StoppingCriteriaList - ) - - GENERATION_TYPES = { - "top_k": TopKLogitsWarper, - "top_p": TopPLogitsWarper, - "beam_search": "beam_search" - } - _has_transformers = True -except ImportError as e: - GENERATION_TYPES = { - "top_k": None, - "top_p": None, - "beam_search": "beam_search" - } - _has_transformers = False +from .generation_utils import Generator @dataclass @@ -48,6 +23,10 @@ class MultimodalCfg(CLIPTextCfg): heads: int = 8 n_queries: int = 256 attn_pooler_heads: int = 8 + cross_attn_ratio: int = 1 + does_full_decoding: bool = False + output_tokens: bool = False + has_mlp: bool = True def _build_text_decoder_tower( @@ -55,6 +34,7 @@ def _build_text_decoder_tower( multimodal_cfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, + is_decoder=True, ): multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg act_layer = QuickGELU if quick_gelu else nn.GELU @@ -68,7 +48,10 @@ def _build_text_decoder_tower( heads=multimodal_cfg.heads, layers=multimodal_cfg.layers, ls_init_value=multimodal_cfg.ls_init_value, + cross_attn_ratio=multimodal_cfg.cross_attn_ratio, + has_mlp=multimodal_cfg.has_mlp, output_dim=embed_dim, + output_tokens=multimodal_cfg.output_tokens, act_layer=act_layer, norm_layer=norm_layer, ) @@ -76,7 +59,7 @@ def _build_text_decoder_tower( return decoder -class CoCa(nn.Module): +class CoCa(nn.Module, Generator): def __init__( self, embed_dim, @@ -148,13 +131,8 @@ def encode_text(self, text, normalize: bool = True): text_latent, _ = self._encode_text(text, normalize=normalize) return text_latent - def forward( - self, - image, - text: Optional[torch.Tensor] = None, - image_latent: Optional[torch.Tensor] = None, - image_embs: Optional[torch.Tensor] = None, - ): + def forward(self, image=None, text=None, image_latent=None, image_embs=None, is_training=True): + text_latent, token_embs = self._encode_text(text) if image_latent is None or image_embs is None: image_latent, image_embs = self._encode_image(image) @@ -164,7 +142,9 @@ def forward( text_latent, token_embs = self._encode_text(text) # TODO: add assertion to avoid bugs? 
- labels = text[:, -token_embs.shape[1]:] + labels = text[:, 1:] + if is_training: + token_embs = token_embs[:, :-1] logits = self.text_decoder(image_embs, token_embs) return { @@ -175,295 +155,3 @@ def forward( "logit_scale": self.logit_scale.exp() } - def generate( - self, - image, - text=None, - seq_len=30, - max_seq_len=77, - temperature=1., - generation_type="beam_search", - top_p=0.1, # keep tokens in the 1 - top_p quantile - top_k=1, # keeps the top_k most probable tokens - pad_token_id=None, - eos_token_id=None, - sot_token_id=None, - num_beams=6, - num_beam_groups=3, - min_seq_len=5, - stopping_criteria=None, - repetition_penalty=1.0, - fixed_output_length=False # if True output.shape == (batch_size, seq_len) - ): - # taking many ideas and components from HuggingFace GenerationMixin - # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation - assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." - assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" - - with torch.no_grad(): - sot_token_id = 49406 if sot_token_id is None else sot_token_id - eos_token_id = 49407 if eos_token_id is None else eos_token_id - pad_token_id = self.pad_id if pad_token_id is None else pad_token_id - logit_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(min_seq_len, eos_token_id), - RepetitionPenaltyLogitsProcessor(repetition_penalty), - ] - ) - - if stopping_criteria is None: - stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] - - stopping_criteria = StoppingCriteriaList( - stopping_criteria - ) - - device = image.device - - if generation_type == "beam_search": - output = self._generate_beamsearch( - image_inputs=image, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - sot_token_id=sot_token_id, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - min_seq_len=min_seq_len, - stopping_criteria=stopping_criteria, - logit_processor=logit_processor, - ) - if fixed_output_length and output.shape[1] < seq_len: - return torch.cat( - (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), - dim=1 - ) - return output - - elif generation_type == "top_p": - logit_warper = GENERATION_TYPES[generation_type](top_p) - elif generation_type == "top_k": - logit_warper = GENERATION_TYPES[generation_type](top_k) - else: - raise ValueError( - f"generation_type has to be one of " - f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
- ) - - image_latent, image_embs = self._encode_image(image) - - if text is None: - text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id - - was_training = self.training - num_dims = len(text.shape) - - if num_dims == 1: - text = text[None, :] - - cur_len = text.shape[1] - self.eval() - out = text - - while True: - x = out[:, -max_seq_len:] - cur_len = x.shape[1] - logits = self(image, x, image_latent=image_latent, image_embs=image_embs)["logits"][:, -1] - mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) - sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id - - if mask.all(): - if not fixed_output_length: - break - else: - logits = logits[~mask, :] - filtered_logits = logit_processor(x[~mask, :], logits) - filtered_logits = logit_warper(x[~mask, :], filtered_logits) - probs = F.softmax(filtered_logits / temperature, dim=-1) - - if (cur_len + 1 == seq_len): - sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id - else: - sample[~mask, :] = torch.multinomial(probs, 1) - - out = torch.cat((out, sample), dim=-1) - - cur_len += 1 - - if stopping_criteria(out, None): - break - - if num_dims == 1: - out = out.squeeze(0) - - self.train(was_training) - return out - - def _generate_beamsearch( - self, - image_inputs, - pad_token_id=None, - eos_token_id=None, - sot_token_id=None, - num_beams=6, - num_beam_groups=3, - min_seq_len=5, - stopping_criteria=None, - logit_processor=None, - logit_warper=None, - ): - device = image_inputs.device - batch_size = image_inputs.shape[0] - image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) - image_latent, image_embs = self._encode_image(image_inputs) - - input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) - input_ids = input_ids * sot_token_id - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=device, - num_beam_groups=num_beam_groups, - ) - # instantiate logits processors - logits_processor = ( - LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) - if logit_processor is None - else logit_processor - ) - - num_beams = beam_scorer.num_beams - num_beam_groups = beam_scorer.num_beam_groups - num_sub_beams = num_beams // num_beam_groups - batch_size = len(beam_scorer._beam_hyps) // num_beam_groups - batch_beam_size, cur_len = input_ids.shape - beam_indices = None - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) - # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in - # the same group don't produce same tokens everytime. 
- beam_scores[:, ::num_sub_beams] = 0 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - while True: - - # predicted tokens in cur_len step - current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) - - # indices which will form the beams in the next time step - reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) - - # do one decoder step on all beams of all sentences in batch - model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) - outputs = self( - model_inputs['images'], - model_inputs['text'], - image_latent=image_latent, - image_embs=image_embs - ) - - for beam_group_idx in range(num_beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - - for batch_idx in range(batch_size): - batch_group_indices.extend( - [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] - ) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of currentg group only - next_token_logits = outputs['logits'][batch_group_indices, -1, :] - vocab_size = next_token_logits.shape[-1] - - next_token_scores_processed = logits_processor( - group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx - ) - next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as(next_token_scores_processed) - - # reshape for beam search - next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=process_beam_indices, - group_index=beam_group_idx, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - current_tokens[batch_group_indices] = group_input_ids[:, -1] - - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) - ) - - input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) - - # increase cur_len - cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, None): - break - - final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - 
beam_indices=final_beam_indices, - ) - return sequence_outputs['sequences'] - - -def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - else: - position_ids = None - return { - "text": input_ids, - "images": image_inputs, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask, - } diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py index 68456d758..84c311218 100644 --- a/src/open_clip/factory.py +++ b/src/open_clip/factory.py @@ -12,6 +12,8 @@ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg +from .mammut_model import MaMMUT + from .coca_model import CoCa from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss from .openai import load_openai_model @@ -20,6 +22,7 @@ from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH +_IMAGE_CAPTIONING_MODELS = ["coca", "mammut"] HF_HUB_PREFIX = 'hf-hub:' _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] @@ -244,6 +247,8 @@ def create_model( if custom_text: if "multimodal_cfg" in model_cfg: model = CoCa(**model_cfg, cast_dtype=cast_dtype) + elif "mammut" in model_name: + model = MaMMUT(**model_cfg, cast_dtype=cast_dtype) else: model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) else: @@ -328,7 +333,7 @@ def create_loss(args): world_size=args.world_size, use_horovod=args.horovod, ) - elif "coca" in args.model.lower(): + elif any(m in args.model.lower() for m in _IMAGE_CAPTIONING_MODELS): return CoCaLoss( caption_loss_weight=args.coca_caption_loss_weight, clip_loss_weight=args.coca_contrastive_loss_weight, diff --git a/src/open_clip/generation_utils.py b/src/open_clip/generation_utils.py index e69de29bb..bb6a62ac0 100644 --- a/src/open_clip/generation_utils.py +++ b/src/open_clip/generation_utils.py @@ -0,0 +1,328 @@ + +import torch +import torch.nn.functional as F + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + +class Generator: + + def __init__(self): + super().__init__() + + def generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, 
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + + with torch.no_grad(): + sot_token_id = 49406 if sot_token_id is None else sot_token_id + eos_token_id = 49407 if eos_token_id is None else eos_token_id + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + + stopping_criteria = StoppingCriteriaList( + stopping_criteria + ) + + device = image.device + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs = image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + return torch.cat( + (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
+ ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + cur_len = text.shape[1] + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self(image, x, image_latent=image_latent, image_embs=image_embs, is_training=False)["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if stopping_criteria(out, None): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. 
+ beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + image_latent=image_latent, + image_embs=image_embs, + is_training=False + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + group_index=beam_group_idx, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or stopping_criteria(input_ids, None): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + 
max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } diff --git a/src/open_clip/loss.py b/src/open_clip/loss.py index 5beaab1c3..175c52e7a 100644 --- a/src/open_clip/loss.py +++ b/src/open_clip/loss.py @@ -114,7 +114,7 @@ def get_logits(self, image_features, text_features, logit_scale): else: logits_per_image = logit_scale * image_features @ text_features.T logits_per_text = logit_scale * text_features @ image_features.T - + return logits_per_image, logits_per_text def forward(self, image_features, text_features, logit_scale, output_dict=False): @@ -158,9 +158,9 @@ def __init__( self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): - + clip_loss = torch.tensor(0) - + if self.clip_loss_weight: clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss diff --git a/src/open_clip/mammut_model.py b/src/open_clip/mammut_model.py new file mode 100644 index 000000000..37eba5ce4 --- /dev/null +++ b/src/open_clip/mammut_model.py @@ -0,0 +1,167 @@ +from typing import Optional + +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower +from .coca_model import MultimodalCfg +from .transformer import QuickGELU, LayerNormFp32, LayerNorm, MultimodalTransformer +from .generation_utils import Generator + + + +def _build_multimodal_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + cross_attn_ratio=multimodal_cfg.cross_attn_ratio, + does_full_decoding=multimodal_cfg.does_full_decoding, + output_tokens=multimodal_cfg.output_tokens, + has_mlp=multimodal_cfg.has_mlp, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + +class MaMMUT(nn.Module, Generator): + def __init__( + self, + embed_dim: int, + text_cfg: MultimodalCfg, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = ( + CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) 
else vision_cfg + ) + + vocab_size = ( + self.text.config.vocab_size # for hf models + if multimodal_cfg.__dict__.get("hf_model_name", None) is not None + else multimodal_cfg.vocab_size + ) + + self.text = _build_multimodal_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.context_length = multimodal_cfg.context_length + self.map_viz2txt_kv = nn.Parameter(torch.randn(vision_cfg.width, multimodal_cfg.width)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.pad_id = pad_id + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + def _encode_text(self, text, image_embs): + token_logits, text_latent = self.text( + text_embs=text, + image_embs=image_embs, + ) + return token_logits, text_latent + + + def encode_text( + self, + text, + image_embs=None, + normalize=True, + output_logits=False + ): + token_logits, text_latent = self._encode_text( + text=text, + image_embs=image_embs, + ) + + if output_logits: + return token_logits + + text_latent = text_latent.mean(1) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent + + def _encode_image(self, image, normalize: bool=True): + image_latent, image_embs = self.visual(image) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, image_embs + + def encode_image(self, image, normalize: bool=True): + image_latent, _ = self._encode_image(image, normalize=normalize) + return image_latent + + def _forward(self, text, out, image_embs=None, contrastive=True, is_training=True): + + if contrastive: + text_features = self.encode_text(text) + out["text_features"] = text_features + return out + + # adjust image output size for cross_attn + image_embs = image_embs @ self.map_viz2txt_kv + + # TODO: add assertion to avoid bugs? 
+ out["labels"] = text[:, 1:] # shift labels + + text = text[:, :-1] if is_training else text # drop last tok because it has no label + out["logits"] = self.encode_text(text, image_embs=image_embs, output_logits=True) + + return out + + def forward(self, image, text=None, image_latent=None, image_embs=None, is_training=True): + out = {"logit_scale": self.logit_scale.exp()} + + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + out["image_features"] = image_latent + + if text is None: + return out + + if is_training: + out = self._forward(text=text, out=out) + + out = self._forward( + text=text, + out=out, + image_embs=image_embs, + contrastive=False, + is_training=is_training, + ) + + + return out diff --git a/src/open_clip/model_configs/mammut_ViT-B-32.json b/src/open_clip/model_configs/mammut_ViT-B-32.json new file mode 100644 index 000000000..6eb7cabbd --- /dev/null +++ b/src/open_clip/model_configs/mammut_ViT-B-32.json @@ -0,0 +1,24 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32, + "output_tokens": true, + "pool_type": "avg_all", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12, + "output_tokens": true, + "cross_attn_ratio": 2, + "does_full_decoding": true, + "has_mlp": false + }, + "custom_text": true +} \ No newline at end of file diff --git a/src/open_clip/model_configs/mammut_ViT-L-14.json b/src/open_clip/model_configs/mammut_ViT-L-14.json new file mode 100644 index 000000000..1e1b1ffbe --- /dev/null +++ b/src/open_clip/model_configs/mammut_ViT-L-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14, + "output_tokens": true, + "pool_type": "avg_all", + "final_ln_after_pool": true + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12, + "output_tokens": true, + "cross_attn_ratio": 2, + "does_full_decoding": true + }, + "custom_text": true +} diff --git a/src/open_clip/transformer.py b/src/open_clip/transformer.py index 6d4e604d8..100d813f7 100644 --- a/src/open_clip/transformer.py +++ b/src/open_clip/transformer.py @@ -195,6 +195,7 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, is_cross_attention: bool = False, + has_mlp: bool = True, ): super().__init__() @@ -204,14 +205,20 @@ def __init__( if is_cross_attention: self.ln_1_kv = norm_layer(d_model) - self.ln_2 = norm_layer(d_model) - mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) - self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + self.has_mlp = has_mlp + if self.has_mlp: + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + else: + self.ln2 = None + self.mlp = None + self.ls_2 = None def attention( self, @@ -239,7 +246,8 @@ def forward( v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), 
k_x=k_x, v_x=v_x, attn_mask=attn_mask)) - x = x + self.ls_2(self.mlp(self.ln_2(x))) + if self.has_mlp: + x = x + self.ls_2(self.mlp(self.ln_2(x))) return x @@ -324,6 +332,18 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = r(x, attn_mask=attn_mask) return x + def init_parameters(self): + proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5) + attn_std = self.width ** -0.5 + fc_std = (2 * self.width) ** -0.5 + for block in self.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + return proj_std, attn_std, fc_std + class VisionTransformer(nn.Module): output_tokens: torch.jit.Final[bool] @@ -351,7 +371,7 @@ def __init__( output_tokens: bool = False, ): super().__init__() - assert pool_type in ('tok', 'avg', 'none') + assert pool_type in ('tok', 'avg', 'avg_all', 'none') self.output_tokens = output_tokens image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) @@ -492,6 +512,8 @@ def set_grad_checkpointing(self, enable=True): def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.pool_type == 'avg': pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] + elif self.pool_type == "avg_all": + pooled, tokens = x.mean(dim=1), x elif self.pool_type == 'tok': pooled, tokens = x[:, 0], x[:, 1:] else: @@ -543,7 +565,7 @@ def forward(self, x: torch.Tensor): if self.output_tokens: return pooled, tokens - + return pooled @@ -584,6 +606,7 @@ def __init__( norm_layer: Callable = LayerNorm, output_tokens: bool = False, ): + super().__init__() assert pool_type in ('first', 'last', 'argmax', 'none') self.output_tokens = output_tokens @@ -602,6 +625,8 @@ def __init__( else: self.cls_emb = None self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + + self.transformer = Transformer( width=width, layers=layers, @@ -611,6 +636,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, ) + self.ln_final = norm_layer(width) if no_causal_mask: @@ -631,15 +657,7 @@ def init_parameters(self): if self.cls_emb is not None: nn.init.normal_(self.cls_emb, std=0.01) - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - + self.transformer.init_parameters() if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) @@ -662,7 +680,7 @@ def build_causal_mask(self): def build_cls_mask(self, text, cast_dtype: torch.dtype): cls_mask = (text != self.pad_id).unsqueeze(1) - cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=1.0) additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) additive_mask.fill_(0) additive_mask.masked_fill_(~cls_mask, float("-inf")) @@ -674,6 +692,7 @@ def forward(self, text): seq_len = text.shape[1] x = 
self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask if self.cls_emb is not None: seq_len += 1 @@ -684,7 +703,7 @@ def forward(self, text): x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=attn_mask) + x = self.transformer(x, attn_mask) x = x.permute(1, 0, 2) # LND -> NLD # x.shape = [batch_size, n_ctx, transformer.width] @@ -708,7 +727,9 @@ def forward(self, text): return pooled -class MultimodalTransformer(Transformer): +class MultimodalTransformer(nn.Module): + does_full_decoding: torch.jit.Final[bool] + def __init__( self, width: int, @@ -719,84 +740,149 @@ def __init__( ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, + cross_attn_ratio = 1, output_dim: int = 512, + does_full_decoding: bool = False, # if this is false below values are useless + vocab_size: int = 49408, + output_tokens: bool = False, + has_mlp: bool = True, ): - super().__init__( - width=width, - layers=layers, - heads=heads, - mlp_ratio=mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - ) + super().__init__() + + self.width = width + self.layers = layers + self.grad_checkpointing = False self.context_length = context_length - self.cross_attn = nn.ModuleList([ - ResidualAttentionBlock( - width, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - is_cross_attention=True, + + n_cross_attn, _ = divmod(layers, cross_attn_ratio) + self.cross_step, _ = divmod(layers, n_cross_attn) + + self.resblocks = nn.ModuleList([]) + self.cross_attn = nn.ModuleList([]) + + for l_idx in range(layers): + + _, r = divmod(l_idx, self.cross_step) + has_cross_attn = r == 0 + + self.resblocks.append( + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + has_mlp=(not has_cross_attn) or has_mlp, + ) ) - for _ in range(layers) - ]) + + if has_cross_attn: + self.cross_attn.append( + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + ) + + assert len(self.cross_attn) == n_cross_attn, "the number of cross attn is incorrect" self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) self.ln_final = norm_layer(width) self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + self.does_full_decoding = does_full_decoding + + if self.does_full_decoding: + self.num_pos = self.context_length + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + else: + self.num_pos = None + self.token_embedding = None + self.positional_embedding = None + + self.output_tokens = output_tokens + + self.init_parameters() def init_parameters(self): - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - for block in self.transformer.cross_attn: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - 
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + nn.init.zeros_(self.text_projection) + + if self.does_full_decoding: + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) def build_attention_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.context_length, self.context_length) mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal + mask.triu_(1) # zero out the lower diagonal return mask + def get_cast_dtype(self) -> torch.dtype: + for resblock in self.resblocks: + if hasattr(resblock, 'mlp') and resblock.mlp is not None: + if hasattr(resblock.mlp.c_fc, 'int8_original_dtype'): + return resblock.mlp.c_fc.int8_original_dtype + return resblock.mlp.c_fc.weight.dtype + def forward(self, image_embs, text_embs): - text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq - image_embs = image_embs.permute(1, 0, 2) # NLD -> LND - seq_len = text_embs.shape[0] + seq_len = text_embs.shape[1] + if self.does_full_decoding: + cast_dtype = self.get_cast_dtype() + text_embs = self.token_embedding(text_embs).to(cast_dtype) # [batch_size, n_ctx, d_model] + text_embs = text_embs + self.positional_embedding[:seq_len].to(cast_dtype) + + text_embs = text_embs.permute(1, 0, 2) # NLD -> LND + if image_embs is not None: + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + + # TODO: handle different cases better, currently + # differentiates coca from mammut based on image_embs + if image_embs is not None: + attn_mask = self.attn_mask + attn_mask = attn_mask[:seq_len, :seq_len] + else: + attn_mask = None + + for idx, resblock in enumerate(self.resblocks): + cross_attn_idx, r = divmod(idx, self.cross_step) + do_cross_attn = r == 0 and image_embs is not None - for resblock, cross_attn in zip(self.resblocks, self.cross_attn): if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 - text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) - text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + text_embs = checkpoint(resblock, text_embs, None, None, attn_mask) + if do_cross_attn: + cross_attn = self.cross_attn[cross_attn_idx] + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs) else: - text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) - text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + text_embs = resblock(text_embs, None, None, attn_mask=attn_mask) + if do_cross_attn: + cross_attn = self.cross_attn[cross_attn_idx] + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + assert cross_attn_idx == len(self.cross_attn) - 1, "some cross attentions are being skipped" x = text_embs.permute(1, 0, 2) # LND -> NLD x = self.ln_final(x) if self.text_projection is not None: - x = x @ self.text_projection + logits = x @ self.text_projection - return x + if self.output_tokens: + return logits, x + + return logits @torch.jit.ignore def set_grad_checkpointing(self, enable=True): diff --git a/tests/test_inference.py b/tests/test_inference.py index dca8dc44c..015308f5f 100644 --- 
a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -28,7 +28,6 @@
     'ViT-e-14',
     'mt5-xl-ViT-H-14',
     'coca_base',
-    'coca_ViT-B-32',
     'coca_roberta-ViT-B-32'
 })
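A note on the new cross-attention scheduling in MultimodalTransformer, since the divmod arithmetic is easy to misread: layers // cross_attn_ratio cross-attention blocks are created, and one is attached after every cross_step-th self-attention block. The sketch below is illustrative only and not part of the diff; it simply mirrors that scheduling, and the helper name cross_attention_layout is made up for the example.

def cross_attention_layout(layers: int, cross_attn_ratio: int) -> list:
    """Mirror of the scheduling in MultimodalTransformer.__init__:
    True at index l means the self-attention block at layer l is
    followed by a cross-attention block."""
    n_cross_attn = layers // cross_attn_ratio
    cross_step = layers // n_cross_attn
    return [l % cross_step == 0 for l in range(layers)]

# CoCa-style decoder (cross_attn_ratio=1): cross-attention after every layer.
assert cross_attention_layout(12, 1) == [True] * 12

# mammut_ViT-B-32 config (layers=12, cross_attn_ratio=2): 6 cross-attention
# blocks, attached after the self-attention blocks at layers 0, 2, 4, 6, 8, 10.
assert sum(cross_attention_layout(12, 2)) == 6

With has_mlp set to false in the mammut_ViT-B-32 config, the has_mlp=(not has_cross_attn) or has_mlp expression in the diff means the decoder blocks that are followed by cross-attention are built without their own MLP, while the remaining self-attention blocks keep theirs.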
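An end-to-end usage sketch of how the refactor is meant to be exercised: both CoCa and the new MaMMUT model inherit generate() from the shared Generator mixin in generation_utils.py, so caption generation is invoked identically for either architecture. This is a minimal sketch, assuming the mammut_ViT-B-32 config added above is picked up by the factory, that no pretrained checkpoint exists yet (weights are randomly initialized), and with a dummy tensor standing in for a preprocessed image.

import torch
import open_clip

# Build the new MaMMUT model from the config added in this PR.
# No pretrained tag is assumed; the weights are randomly initialized.
model, _, _ = open_clip.create_model_and_transforms("mammut_ViT-B-32")
model.eval()

# Dummy batch of one 224x224 RGB image in place of real preprocessed input.
image = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # generate() lives in generation_utils.Generator and is shared by CoCa
    # and MaMMUT; beam search requires `transformers` to be installed.
    tokens = model.generate(image, generation_type="beam_search", seq_len=20)

print(tokens.shape)  # (1, <=20) token ids; decodable with open_clip's tokenizer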