# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Mapping, Protocol

import torch


class Transform(Protocol):
    """
    Loose interface for all data and model transforms. Transforms operate at the
    sample level and perform operations on a sample dict, returning the updated dict.
    """

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        pass


class VisionCrossAttentionMask(Transform):
    """
    Computes the cross-attention mask for text + image inputs. Text tokens that
    participate in cross-attention with an image token are marked True in the mask,
    following the interleaved structure laid out in Fig. 7 of the Flamingo paper
    (https://arxiv.org/pdf/2204.14198):

    (1) Text tokens immediately following an image token attend to it, up until
        the next image token
    (2) Consecutive image tokens attend to the same subsequent text tokens

    ::

                 ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
            img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
                 └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
                 ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
            img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
                 └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
                 ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
            img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
                 └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
                <img1> <img2>These are two dogs. <img3> This is a cat.

    The resultant mask is constructed per image and is of shape (text_seq_len, image_seq_len),
    where True indicates that the token output by the image encoder attends to the token
    in the text sequence in cross-attention. A list of these masks is returned, with
    length equal to the number of images in the sample.

    Args:
        tile_size (int): The size of the image tiles from the image transform
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for patch_size = 40, a tile of shape (400, 400) will have a 10x10 grid of patches
            with shape (40, 40) each.
        image_token_id (int): Token ID of the image special token.
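
    Example:
        The values below are purely illustrative; the token IDs, tile/patch sizes, and
        image tensor are made up for demonstration:

        >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
        >>> sample = {
        ...     "tokens": [1, 9673, 527, 1403, 12875, 13],
        ...     "images": [torch.randn(2, 3, 400, 400)],
        ... }
        >>> sample = transform(sample)
        >>> sample["encoder_mask"][0].shape  # (text_seq_len, n_tiles * (patches_per_tile + 1))
        torch.Size([6, 202])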
    """

    def __init__(self, tile_size: int, patch_size: int, image_token_id: int):
        # Tiles are square, so the number of patches per tile is the patch grid size
        # squared, e.g. tile_size=400, patch_size=40 -> 10 * 10 = 100 patches per tile
        patch_grid_size = tile_size // patch_size
        self.patches_per_tile = patch_grid_size**2
        self.image_token_id = image_token_id

    def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
        """
        Returns a list of lists of the form [start, end) where start is the index
        of the current image token and end is the index of the next image token, exclusive.

        Args:
            tokens (List[int]): List of token IDs in the text sequence

        Returns:
            List[List[int]]: List of lists of the form [start, end) indicating
                the range of positions in the text sequence that should attend to the image

        Example:
            >>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
            >>> image_token_id = 1
            >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415]
            >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
            >>> intervals = transform._get_image_attention_intervals(tokens)
            >>> print(intervals)
            [[0, 7], [1, 7], [7, 12]]
        """
        end = len(tokens)
        vision_token_locations = [
            i for i, token in enumerate(tokens) if token == self.image_token_id
        ]
        # Return empty list if there are no images
        if len(vision_token_locations) == 0:
            return []
        # If there is only one image, it will attend to subsequent text until end
        if len(vision_token_locations) == 1:
            return [[vision_token_locations[0], end]]

        # Construct intervals from previous image token to next image token
        vision_masks = [
            [tok_idx_prev, tok_idx_next]
            # Zip the list with itself shifted by one to pair each image token
            # index with the index of the next image token
            for tok_idx_prev, tok_idx_next in zip(
                vision_token_locations[:-1], vision_token_locations[1:]
            )
        ]
        # Last image will attend to subsequent text until end
        vision_masks.append([vision_token_locations[-1], end])

        # If there are consecutive vision tokens, they should all attend to the
        # same subsequent text
        last_mask_end = vision_masks[-1][1]
        for vision_mask in vision_masks[::-1]:
            if vision_mask[0] == vision_mask[1] - 1:
                vision_mask[1] = last_mask_end
            last_mask_end = vision_mask[1]
        return vision_masks

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Generates the vision cross-attention mask for the given sample based on
        the image token locations interleaved in the text sequence.

        Args:
            sample (Mapping[str, Any]): Sample dict containing the following keys:
                - tokens (List[int]): List of token IDs in the text sequence. The number of
                  image token IDs in the sequence must match the number of images.
                - images (List[torch.Tensor]): List of image Tensors post-tiling of shape
                  (n_tiles, c, h, w) each.

        Returns:
            Mapping[str, Any]: updated sample with the following keys:
                - encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len),
                  where the length of the list equals the number of images in the sample
                - tokens (List[int]): original tokens
                - images (List[torch.Tensor]): original images
        """
        tokens, images = sample["tokens"], sample["images"]
        # One sample can have multiple images; verify that the number of image
        # token intervals matches the number of images
        n_img = len(images)
        intervals = self._get_image_attention_intervals(tokens)
        if len(intervals) != n_img:
            raise RuntimeError(
                f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})."
            )

        # Create a mask for each individual image based on its number of tokens,
        # which can vary with the number of tiles since the images are not yet tile padded.
        # The masks are padded and concatenated together in the batch collator
        text_seq_len = len(tokens)
        masks = []
        for image_num, interval in enumerate(intervals):
            # Identify what part of the text sequence should be attended
            start, end = interval
            # Compute this image's number of tokens from its num tiles and patches per tile
            n_tiles = images[image_num].shape[0]
            image_seq_len = n_tiles * (self.patches_per_tile + 1)  # +1 for CLS token
            # The mask is a block of True values over the corresponding interval in the text.
            # It is not a causal block because all the image tokens correspond
            # to a single image, so text tokens attend to all of the image's tokens
            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
            mask[start:end, :] = True
            masks.append(mask)

        sample.update({"encoder_mask": masks})
        return sample
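

# The smoke test below is an illustrative sketch only: the tokens, image shapes,
# and tile/patch sizes are made-up assumptions chosen to mirror the interleaved
# example in the class docstring.
if __name__ == "__main__":
    transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
    # "<img1><img2>These are two dogs. <img3>This is a cat." with image token ID 1
    sample = {
        "tokens": [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415],
        "images": [torch.randn(1, 3, 400, 400) for _ in range(3)],  # one tile each
    }
    out = transform(sample)
    # Each mask is (text_seq_len, n_tiles * (patches_per_tile + 1)) = (12, 101)
    print([tuple(m.shape) for m in out["encoder_mask"]])
    # Number of text rows attending to each image: 7, 6, and 5,
    # matching the intervals [0, 7), [1, 7), and [7, 12)
    print([int(m.any(dim=1).sum()) for m in out["encoder_mask"]])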