diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index ba6297a744..f6e8f93b38 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -95,3 +95,4 @@ Functions used for preprocessing images. transforms.resize_with_pad transforms.tile_crop transforms.find_supported_resolutions + transforms.VisionCrossAttentionMask diff --git a/tests/torchtune/modules/transforms/test_transforms.py b/tests/torchtune/modules/transforms/test_transforms.py new file mode 100644 index 0000000000..0436a34ec1 --- /dev/null +++ b/tests/torchtune/modules/transforms/test_transforms.py @@ -0,0 +1,80 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torchtune.modules.transforms import VisionCrossAttentionMask + + +IMAGE_TOKEN_ID = 1 + + +class TestVisionCrossAttentionMask: + @pytest.fixture + def num_tiles(self): + return 2 + + @pytest.fixture + def tile_size(self): + return 4 + + @pytest.fixture + def patch_size(self): + return 2 + + @pytest.fixture + def image_num_tokens(self, num_tiles, tile_size, patch_size): + return ((tile_size // patch_size) ** 2 + 1) * num_tiles + + @pytest.fixture + def tokens(self): + # This tests image tokens not at start, consecutive images, and image + # with text until end. + # text = 2, image = 1 + return [2, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 2, IMAGE_TOKEN_ID, 2, 2] + + @pytest.fixture + def images(self, num_tiles, tokens): + n_img = len([token_id for token_id in tokens if token_id == IMAGE_TOKEN_ID]) + n_channels = 3 + tile_size = 2 + return [ + torch.ones(num_tiles, n_channels, tile_size, tile_size) + for _ in range(n_img) + ] + + @pytest.fixture + def cross_attn_mask_transform(self, tile_size, patch_size): + # patches per tile = 4 + return VisionCrossAttentionMask( + tile_size=tile_size, + patch_size=patch_size, + image_token_id=IMAGE_TOKEN_ID, + ) + + def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens): + actual = cross_attn_mask_transform._get_image_attention_intervals(tokens) + expected = [[2, 6], [3, 6], [6, 9]] + assert actual == expected + + def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens): + sample = {"tokens": tokens, "images": images} + dummy_kwargs = {"hello": 8} + sample.update(dummy_kwargs) + actual = cross_attn_mask_transform(sample) + expected = [ + torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool) + for _ in range(len(images)) + ] + expected[0][2:6, :] = True + expected[1][3:6, :] = True + expected[2][6:9, :] = True + for i in range(len(images)): + torch.testing.assert_close(actual["encoder_mask"][i], expected[i]) + torch.testing.assert_close(actual["images"][i], images[i]) + + assert actual["tokens"] == tokens + assert actual["hello"] == dummy_kwargs["hello"] diff --git a/torchtune/modules/transforms/__init__.py b/torchtune/modules/transforms/__init__.py index afc203812f..c317e7d7ce 100644 --- a/torchtune/modules/transforms/__init__.py +++ b/torchtune/modules/transforms/__init__.py @@ -4,16 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from .vision_utils.get_canvas_best_fit import (  # noqa
+from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask
+from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (  # noqa
     find_supported_resolutions,
     get_canvas_best_fit,
 )
-from .vision_utils.resize_with_pad import resize_with_pad  # noqa
-from .vision_utils.tile_crop import tile_crop  # noqa
+from torchtune.modules.transforms.vision_utils.resize_with_pad import (  # noqa
+    resize_with_pad,
+)
+from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop  # noqa
 
 __all__ = [
+    "Transform",
     "get_canvas_best_fit",
     "resize_with_pad",
     "tile_crop",
     "find_supported_resolutions",
+    "VisionCrossAttentionMask",
 ]
diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py
new file mode 100644
index 0000000000..68142686fb
--- /dev/null
+++ b/torchtune/modules/transforms/_transforms.py
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Mapping, Protocol
+
+import torch
+
+
+class Transform(Protocol):
+    """
+    Loose interface for all data and model transforms. Transforms operate at the
+    sample level and perform operations on a sample dict, returning the updated dict.
+    """
+
+    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        pass
+
+
+class VisionCrossAttentionMask(Transform):
+    """
+    Computes the cross-attention mask for text + image inputs. Text tokens that
+    participate in cross-attention with an image token are marked True in the mask,
+    following the interleaved structure laid out in Fig. 7 of the Flamingo paper
+    (https://arxiv.org/pdf/2204.14198):
+
+    (1) Text tokens immediately following an image token, up until the next image
+        token, attend to that image
+    (2) Consecutive image tokens attend to the same subsequent text tokens
+
+    ::
+
+              ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
+        img1  │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
+              └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
+              ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
+        img2  │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
+              └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
+              ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
+        img3  │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
+              └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
+             <img1> <img2> These  are   two  dogs. <img3> This   is    a    cat.
+
+    The resulting mask is constructed per image and has shape (text_seq_len, image_seq_len),
+    where True indicates that the text token attends to the corresponding token output by
+    the image encoder in cross-attention. A list of these masks is returned, with length
+    equal to the number of images in the sample.
+
+    Args:
+        tile_size (int): The size of the image tiles from the image transform
+        patch_size (int): The size of each patch. Used to divide the tiles into patches.
+            E.g. for patch_size = 40, a tile of shape (400, 400) will have a 10x10 grid
+            of patches with shape (40, 40) each.
+        image_token_id (int): Token ID of the image special token.
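+
+    Example:
+        A minimal sketch of the expected output; the token IDs and the image
+        shape below are made up purely for illustration:
+
+        >>> import torch
+        >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
+        >>> tokens = [1, 9673, 527, 1403, 12875]  # one image token followed by text tokens
+        >>> image = torch.ones(2, 3, 400, 400)  # a single image split into 2 tiles
+        >>> out = transform({"tokens": tokens, "images": [image]})
+        >>> out["encoder_mask"][0].shape  # (text_seq_len, n_tiles * (patches_per_tile + 1))
+        torch.Size([5, 202])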
+ """ + + def __init__(self, tile_size: int, patch_size: int, image_token_id: int): + patch_grid_size = tile_size // patch_size + self.patches_per_tile = patch_grid_size**2 + self.image_token_id = image_token_id + + def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]: + """ + Returns a list of lists of the form [start, end) where start is the index + of the current image token and end is the index of the next image token, exclusive. + + Args: + tokens (List[int]): List of token IDs in the text sequence + + Returns: + List[List[int]]: List of lists of the form [start, end) indicating + range of positions in text sequence that should attend to the image + + Example: + >>> text = "These are two dogs. This is a cat." + >>> image_token_id = 1 + >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415] + >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1) + >>> intervals = transform._get_image_attention_intervals(tokens) + >>> print(intervals) + [[0, 7], [1, 7], [7, 12]] + """ + end = len(tokens) + vision_token_locations = [ + i for i, token in enumerate(tokens) if token == self.image_token_id + ] + # Return empty list if there are no images + if len(vision_token_locations) == 0: + return [] + # If there is only one image, it will attend to subsequent text until end + if len(vision_token_locations) == 1: + return [[vision_token_locations[0], end]] + + # Construct intervals from previous image token to next image token + vision_masks = [ + [tok_idx_prev, tok_idx_next] + # Offset by one to get consecutive indices + for tok_idx_prev, tok_idx_next in zip( + vision_token_locations[:-1], vision_token_locations[1:] + ) + ] + # Last image will attend to subsequent text until end + vision_masks.append([vision_token_locations[-1], end]) + + # If there are consecutive vision tokens, they should all attend to the + # same subsequent text + last_mask_end = vision_masks[-1][1] + for vision_mask in vision_masks[::-1]: + if vision_mask[0] == vision_mask[1] - 1: + vision_mask[1] = last_mask_end + last_mask_end = vision_mask[1] + return vision_masks + + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Generates the vision cross-attention mask for the given sample based on + the image token locations interleaved in the text sequence. + + Args: + sample (Mapping[str, Any]): Sample dict containing the following keys: + - tokens (List[int]): List of token IDs in the text sequence. Number of + image token IDs in the sequence must match the number of images. + - images (List[torch.Tensor]): List of image Tensors post-tiling of shape + (n_tiles, c, h, w) each. + + Returns: + Mapping[str, Any]: updated sample with the following keys: + - encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len), + where length of list == number of images in sample + - tokens (List[int]): original tokens + - images (List[torch.Tensor]): original images + """ + tokens, images = sample["tokens"], sample["images"] + # One sample can have multiple images - verify the number of image tokens + # is the same + n_img = len(images) + intervals = self._get_image_attention_intervals(tokens) + if len(intervals) != n_img: + raise RuntimeError( + f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})." + ) + + # Create mask for each individual image based on its number of tokens, + # which can vary based on number of tiles since they are not yet tile padded. 
+        # The masks are padded and concatenated together in the batch collator
+        text_seq_len = len(tokens)
+        masks = []
+        for image_num, interval in enumerate(intervals):
+            # Identify which part of the text sequence should attend to this image
+            start, end = interval
+            # Compute this image's number of tokens from its tile count and patches per tile
+            n_tiles = images[image_num].shape[0]
+            image_seq_len = n_tiles * (self.patches_per_tile + 1)  # +1 for CLS token
+            # The mask is a block of True values over the corresponding interval in the text.
+            # It is not a causal block because all the image tokens correspond
+            # to a single image, so text tokens attend to all of the image's tokens
+            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
+            mask[start:end, :] = True
+            masks.append(mask)
+
+        sample.update({"encoder_mask": masks})
+        return sample
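
Note: the per-image masks produced above share text_seq_len but can differ in image_seq_len when images have different tile counts; as the comments say, they are padded and concatenated in the batch collator. A rough sketch of that padding step (a hypothetical helper shown only for illustration, not the collator torchtune actually uses):

    from typing import List

    import torch


    def pad_and_stack_encoder_masks(masks: List[torch.Tensor]) -> torch.Tensor:
        # All masks from one sample share text_seq_len; only image_seq_len
        # varies with the number of tiles per image.
        text_seq_len = masks[0].shape[0]
        max_image_seq_len = max(m.shape[1] for m in masks)
        out = torch.zeros(len(masks), text_seq_len, max_image_seq_len, dtype=torch.bool)
        for i, mask in enumerate(masks):
            # Copy each mask into the left part of its padded slot; padding stays False
            out[i, :, : mask.shape[1]] = mask
        return out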