From d90e5a9207a2078389f5dc51b6beca18016776d3 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 2 Jul 2024 12:19:32 -0700 Subject: [PATCH 1/9] add basic transforms --- torchtune/modules/transforms/__init__.py | 17 ++++++ torchtune/modules/transforms/_transforms.py | 61 +++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 torchtune/modules/transforms/__init__.py create mode 100644 torchtune/modules/transforms/_transforms.py diff --git a/torchtune/modules/transforms/__init__.py b/torchtune/modules/transforms/__init__.py new file mode 100644 index 000000000..cce2ac0a0 --- /dev/null +++ b/torchtune/modules/transforms/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.modules.transforms._transforms import ( + Compose, + TokenizeMessages, + Transform, +) + +__all__ = [ + "Transform", + "Compose", + "TokenizeMessages", +] diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py new file mode 100644 index 000000000..b4df54e89 --- /dev/null +++ b/torchtune/modules/transforms/_transforms.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Protocol + +from torchtune.data import Message +from torchtune.modules.tokenizers import ModelTokenizer + + +class Transform(Protocol): + """ + Loose interface for all data and model transforms. Transforms operate at the + sample level and perform operations on a sample dict which is contained in + kwargs. Any fields that will be processed are unfolded with explicit keyword-arguments, + then the updated dict is returned. 
+ """ + + def __call__(self, **kwargs) -> Mapping[str, Any]: + pass + + +class Compose(Transform): + """ + Compose multiple transforms together, inspired by torchvision's ``Compose`` API + + Args: + transforms (List[Transform]): List of transforms to compose together in sequential order. + """ + + def __init__(self, transforms: List[Transform]) -> None: + self.transforms = transforms + + def __call__(self, **kwargs) -> Mapping[str, Any]: + for transform in self.transforms: + kwargs = transform(**kwargs) + return kwargs + + +class TokenizeMessages(Transform): + """ + Apply the ``tokenize_messages`` method from a given + :class:`~torchtune.modules.tokenizers.ModelTokenizer` on the ``messages`` field of the sample. + + Args: + tokenizer (ModelTokenizer): Tokenizer used by the model that implements + the ``tokenize_messages`` method. + """ + + def __init__(self, tokenizer: ModelTokenizer, max_seq_len: int) -> None: + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + + def __call__(self, *, messages: List[Message], **kwargs) -> Mapping[str, Any]: + tokens, mask = self.tokenizer.tokenize_messages( + messages, max_seq_len=self.max_seq_len + ) + kwargs.update({"tokens": tokens, "mask": mask}) + return kwargs From 5f7e8aa4d87258df8c15a4973161745aa7117a53 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 3 Jul 2024 09:21:38 -0700 Subject: [PATCH 2/9] add xattn mask transform --- torchtune/modules/transforms/_transforms.py | 81 +++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index b4df54e89..a4bee9e15 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -59,3 +59,84 @@ def __call__(self, *, messages: List[Message], **kwargs) -> Mapping[str, Any]: ) kwargs.update({"tokens": tokens, "mask": mask}) return kwargs + + +class CrossAttentionMask(Transform): + """ + Computes the cross-attention mask for 
text + image inputs. Text tokens that + participate in cross-attention with an image token will show True in the mask + and follow these rules: + 1) Text tokens immediately following the image token up until the next image token + 2) Consecutive image tokens attend to all subsequent text tokens + + Resultant mask is of shape (text_seq_len, image_seq_len), where True indicates + that the token outputted from the image encoder attends to the token in the + text sequence. + + Args: + num_patches (int): Number of patches per image, excluding class token. + image_token_id (int): Token ID of the image special token. + """ + + def __init__(self, num_patches: int, image_token_id: int): + self.num_patches = num_patches + self.image_token_id = image_token_id + + def _get_image_attention_intervals( + self, tokens: List[int] + ) -> List[Tuple[int, int]]: + """ + Returns a list of tuples of the form (start, end) where start is the index + of the current image token and end is the index of the next image token, exclusive. + If the image token attends until the end of the sequence, end will be -1. 
+ """ + vision_token_locations = [ + i for i, token in enumerate(tokens) if token == self.image_token_id + ] + # Return empty list if there are no images + if len(vision_token_locations) == 0: + return [] + # If there is only one image, it will attend to subsequent text until end + if len(vision_token_locations) == 1: + return [[vision_token_locations[0], -1]] + + vision_masks = [ + [tok1, tok2] + for tok1, tok2 in zip( + vision_token_locations[:-1], vision_token_locations[1:] + ) + ] + # Last image will attend to subsequent text until end + vision_masks.append([vision_token_locations[-1], -1]) + + # If there are consecutive vision tokens, they should all attend to the + # same subsequent text + last_mask_end = vision_masks[-1][1] + for vision_mask in vision_masks[::-1]: + if vision_mask[0] == vision_mask[1] - 1: + vision_mask[1] = last_mask_end + last_mask_end = vision_mask[1] + return vision_masks + + def __call__(self, *, tokens, images, **kwargs): + # We are still at sample level pre-collating + n_img, n_tiles, _, _, _ = images.shape + text_seq_len = len(tokens) + single_image_seq_len = n_tiles * self.num_patches + 1 + image_seq_len = single_image_seq_len * n_img + intervals = self._get_image_attention_intervals(tokens) + assert len(intervals) == n_img + + mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool) + for image_num, interval in enumerate(intervals): + start, end = interval + end = text_seq_len if end == -1 else end + mask[ + start:end, + image_num + * single_image_seq_len : (image_num + 1) + * single_image_seq_len, + ] = True + + kwargs.update({"encoder_mask": mask, "tokens": tokens, "images": images}) + return kwargs From 9a280f2358673f60c48c060a555f2fe92d0d147a Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 3 Jul 2024 13:51:43 -0700 Subject: [PATCH 3/9] add transforms test --- tests/torchtune/modules/transforms/test_transforms.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 
tests/torchtune/modules/transforms/test_transforms.py diff --git a/tests/torchtune/modules/transforms/test_transforms.py b/tests/torchtune/modules/transforms/test_transforms.py new file mode 100644 index 000000000..958d3546e --- /dev/null +++ b/tests/torchtune/modules/transforms/test_transforms.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Protocol + +from torchtune.data import Message +from torchtune.modules.tokenizers import ModelTokenizer +from torchtune.modules.transforms import Pipeline, TokenizeMessages, CrossAttentionMask From 7f22d58dae9ff03ea4764e26c76e12f251bc5d40 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 8 Jul 2024 17:07:49 -0700 Subject: [PATCH 4/9] only do cross attention mask --- docs/source/api_ref_modules.rst | 1 + .../modules/transforms/test_transforms.py | 70 +++++++- torchtune/modules/transforms/__init__.py | 9 +- torchtune/modules/transforms/_transforms.py | 158 ++++++++++-------- 4 files changed, 159 insertions(+), 79 deletions(-) diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index 261ca2b75..632bec06b 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -93,3 +93,4 @@ Functions used for preprocessing images. transforms.resize_with_pad transforms.tile_crop transforms.find_supported_resolutions + transforms.VisionCrossAttentionMask diff --git a/tests/torchtune/modules/transforms/test_transforms.py b/tests/torchtune/modules/transforms/test_transforms.py index 1551abc58..688578098 100644 --- a/tests/torchtune/modules/transforms/test_transforms.py +++ b/tests/torchtune/modules/transforms/test_transforms.py @@ -4,8 +4,70 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, List, Mapping, Protocol +import pytest +import torch +from torchtune.modules.transforms import VisionCrossAttentionMask -from torchtune.data import Message -from torchtune.modules.tokenizers import ModelTokenizer -from torchtune.modules.transforms import CrossAttentionMask, Pipeline, TokenizeMessages + +IMAGE_TOKEN_ID = 1 + + +class TestVisionCrossAttentionMask: + @pytest.fixture + def num_tiles(self): + return 2 + + @pytest.fixture + def tile_size(self): + return 4 + + @pytest.fixture + def patch_size(self): + return 2 + + @pytest.fixture + def image_num_tokens(self, num_tiles, tile_size, patch_size): + return ((tile_size // patch_size) ** 2 + 1) * num_tiles + + @pytest.fixture + def tokens(self): + # This tests image tokens not at start, consecutive images, and image + # with text until end. + # text = 2, image = 1 + return [2, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 2, IMAGE_TOKEN_ID, 2, 2] + + @pytest.fixture + def images(self, num_tiles, tokens): + n_img = len([i for i in tokens if i == IMAGE_TOKEN_ID]) + return [torch.ones(num_tiles, 3, 2, 2) for _ in range(n_img)] + + @pytest.fixture + def transform(self, tile_size, patch_size): + # patches per tile = 4 + return VisionCrossAttentionMask( + tile_size=tile_size, + patch_size=patch_size, + image_token_id=IMAGE_TOKEN_ID, + ) + + def test_get_image_attention_intervals(self, transform, tokens): + actual = transform._get_image_attention_intervals(tokens) + expected = [[2, 6], [3, 6], [6, 9]] + assert actual == expected + + def test_call(self, transform, tokens, images, image_num_tokens): + dummy_kwargs = {"hello": 8} + actual = transform(tokens=tokens, images=images, **dummy_kwargs) + expected = [ + torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool) + for _ in range(len(images)) + ] + expected[0][2:6, :] = True + expected[1][3:6, :] = True + expected[2][6:9, :] = True + for i in range(len(images)): + torch.testing.assert_close(actual["encoder_mask"][i], expected[i]) + 
torch.testing.assert_close(actual["images"][i], images[i]) + + assert actual["tokens"] == tokens + assert actual["hello"] == dummy_kwargs["hello"] diff --git a/torchtune/modules/transforms/__init__.py b/torchtune/modules/transforms/__init__.py index 637557fe1..c317e7d7c 100644 --- a/torchtune/modules/transforms/__init__.py +++ b/torchtune/modules/transforms/__init__.py @@ -4,11 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.modules.transforms._transforms import ( - Compose, - TokenizeMessages, - Transform, -) +from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import ( # noqa find_supported_resolutions, get_canvas_best_fit, @@ -20,10 +16,9 @@ __all__ = [ "Transform", - "Compose", - "TokenizeMessages", "get_canvas_best_fit", "resize_with_pad", "tile_crop", "find_supported_resolutions", + "VisionCrossAttentionMask", ] diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index a4bee9e15..e71f5222e 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -4,10 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Mapping, Protocol +from typing import Any, List, Mapping, Protocol, Tuple -from torchtune.data import Message -from torchtune.modules.tokenizers import ModelTokenizer +import torch class Transform(Protocol): @@ -22,64 +21,45 @@ def __call__(self, **kwargs) -> Mapping[str, Any]: pass -class Compose(Transform): - """ - Compose multiple transforms together, inspired by torchvision's ``Compose`` API - - Args: - transforms (List[Transform]): List of transforms to compose together in sequential order. 
- """ - - def __init__(self, transforms: List[Transform]) -> None: - self.transforms = transforms - - def __call__(self, **kwargs) -> Mapping[str, Any]: - for transform in self.transforms: - kwargs = transform(**kwargs) - return kwargs - - -class TokenizeMessages(Transform): - """ - Apply the ``tokenize_messages`` method from a given - :class:`~torchtune.modules.tokenizers.ModelTokenizer` on the ``messages`` field of the sample. - - Args: - tokenizer (ModelTokenizer): Tokenizer used by the model that implements - the ``tokenize_messages`` method. +class VisionCrossAttentionMask(Transform): """ + Computes the cross-attention mask for text + image inputs. Text tokens that + participate in cross-attention with an image token will show True in the mask + and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper + (https://arxiv.org/pdf/2204.14198): + 1) Text tokens immediately following the image token up until the next image token + 2) Consecutive image tokens attend to subsequent text tokens - def __init__(self, tokenizer: ModelTokenizer, max_seq_len: int) -> None: - self.tokenizer = tokenizer - self.max_seq_len = max_seq_len + :: - def __call__(self, *, messages: List[Message], **kwargs) -> Mapping[str, Any]: - tokens, mask = self.tokenizer.tokenize_messages( - messages, max_seq_len=self.max_seq_len - ) - kwargs.update({"tokens": tokens, "mask": mask}) - return kwargs + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ + ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ + img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ + └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ 
└───┘ + These are two dogs. This is a cat. -class CrossAttentionMask(Transform): - """ - Computes the cross-attention mask for text + image inputs. Text tokens that - participate in cross-attention with an image token will show True in the mask - and follow these rules: - 1) Text tokens immediately following the image token up until the next image token - 2) Consecutive image tokens attend to all subsequent text tokens Resultant mask is of shape (text_seq_len, image_seq_len), where True indicates that the token outputted from the image encoder attends to the token in the text sequence. Args: - num_patches (int): Number of patches per image, excluding class token. + tile_size (int): The size of the image tiles from the image transform + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. image_token_id (int): Token ID of the image special token. """ - def __init__(self, num_patches: int, image_token_id: int): - self.num_patches = num_patches + def __init__(self, tile_size: int, patch_size: int, image_token_id: int): + patch_grid_size = tile_size // patch_size + self.patches_per_tile = patch_grid_size**2 self.image_token_id = image_token_id def _get_image_attention_intervals( @@ -88,8 +68,24 @@ def _get_image_attention_intervals( """ Returns a list of tuples of the form (start, end) where start is the index of the current image token and end is the index of the next image token, exclusive. - If the image token attends until the end of the sequence, end will be -1. + + Args: + tokens (List[int]): List of token IDs in the text sequence + + Returns: + List[Tuple[int, int]]: List of tuples of the form [start, end) indicating + range of positions in text sequence that should attend to the image + + Example: + >>> text = "These are two dogs. This is a cat." 
+ >>> image_token_id = 1 + >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415] + >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1) + >>> intervals = transform._get_image_attention_intervals(tokens) + >>> print(intervals) + [(0, 7), (1, 7), (7, 12)] """ + end = len(tokens) vision_token_locations = [ i for i, token in enumerate(tokens) if token == self.image_token_id ] @@ -98,16 +94,18 @@ def _get_image_attention_intervals( return [] # If there is only one image, it will attend to subsequent text until end if len(vision_token_locations) == 1: - return [[vision_token_locations[0], -1]] + return [[vision_token_locations[0], end]] + # Construct intervals from previous image token to next image token vision_masks = [ - [tok1, tok2] - for tok1, tok2 in zip( + [tok_idx_prev, tok_idx_next] + # Offset by one to get consecutive indices + for tok_idx_prev, tok_idx_next in zip( vision_token_locations[:-1], vision_token_locations[1:] ) ] # Last image will attend to subsequent text until end - vision_masks.append([vision_token_locations[-1], -1]) + vision_masks.append([vision_token_locations[-1], end]) # If there are consecutive vision tokens, they should all attend to the # same subsequent text @@ -118,25 +116,49 @@ def _get_image_attention_intervals( last_mask_end = vision_mask[1] return vision_masks - def __call__(self, *, tokens, images, **kwargs): - # We are still at sample level pre-collating - n_img, n_tiles, _, _, _ = images.shape - text_seq_len = len(tokens) - single_image_seq_len = n_tiles * self.num_patches + 1 - image_seq_len = single_image_seq_len * n_img + def __call__(self, *, tokens: List[int], images: List[torch.Tensor], **kwargs): + """ + Generates the vision cross-attention mask for the given sample based on + the image token locations interleaved in the text sequence. + + Args: + tokens (List[int]): List of token IDs in the text sequence. 
Number of + image token IDs in the sequence must match the number of images. + images (List[torch.Tensor]): List of image Tensors post-tiling of shape + (n_tiles, c, h, w) each. + **kwargs (Dict[str, Any]): all other keys within the sample that will + not be altered by this transform. + + Returns: + Dict[str, Any]: updated sample with the following keys: + - encoder_mask (List[torch.Tensor]): masks of shape (text_seq_len, image_seq_len) + - tokens (List[int]): original tokens + - images (List[torch.Tensor]): original images + """ + + # One sample can have multiple images - verify the number of image tokens + # is the same + n_img = len(images) intervals = self._get_image_attention_intervals(tokens) assert len(intervals) == n_img - mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool) + # Create mask for each individual image based on its number of tokens, + # which can vary based on number of tiles. The masks are padded and concatenated + # together in the batch collator + text_seq_len = len(tokens) + masks = [] for image_num, interval in enumerate(intervals): + # Identify what part of text sequence should be attended start, end = interval - end = text_seq_len if end == -1 else end - mask[ - start:end, - image_num - * single_image_seq_len : (image_num + 1) - * single_image_seq_len, - ] = True - - kwargs.update({"encoder_mask": mask, "tokens": tokens, "images": images}) + # Compute this image's number of tokens based on num tiles, patches per tile + n_tiles = images[image_num].shape[0] + image_seq_len = n_tiles * (self.patches_per_tile + 1) # +1 for CLS token + # Mask will be block of 1s at the corresponding interval in the text. 
+ # It is not a causal block because all the image tokens correspond + # to a single image, so text tokens attend to all the image's tokens + mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool) + mask[start:end, :] = True + masks.append(mask) + + kwargs.update({"encoder_mask": masks, "tokens": tokens, "images": images}) return kwargs From edfec5a1fd28e8bce4ba31a4f41d151e79515d76 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 9 Jul 2024 13:42:00 -0700 Subject: [PATCH 5/9] address comments --- .../modules/transforms/test_transforms.py | 21 +++++--- torchtune/modules/transforms/_transforms.py | 54 ++++++++++--------- 2 files changed, 42 insertions(+), 33 deletions(-) diff --git a/tests/torchtune/modules/transforms/test_transforms.py b/tests/torchtune/modules/transforms/test_transforms.py index 688578098..0436a34ec 100644 --- a/tests/torchtune/modules/transforms/test_transforms.py +++ b/tests/torchtune/modules/transforms/test_transforms.py @@ -38,11 +38,16 @@ def tokens(self): @pytest.fixture def images(self, num_tiles, tokens): - n_img = len([i for i in tokens if i == IMAGE_TOKEN_ID]) - return [torch.ones(num_tiles, 3, 2, 2) for _ in range(n_img)] + n_img = len([token_id for token_id in tokens if token_id == IMAGE_TOKEN_ID]) + n_channels = 3 + tile_size = 2 + return [ + torch.ones(num_tiles, n_channels, tile_size, tile_size) + for _ in range(n_img) + ] @pytest.fixture - def transform(self, tile_size, patch_size): + def cross_attn_mask_transform(self, tile_size, patch_size): # patches per tile = 4 return VisionCrossAttentionMask( tile_size=tile_size, @@ -50,14 +55,16 @@ def transform(self, tile_size, patch_size): image_token_id=IMAGE_TOKEN_ID, ) - def test_get_image_attention_intervals(self, transform, tokens): - actual = transform._get_image_attention_intervals(tokens) + def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens): + actual = cross_attn_mask_transform._get_image_attention_intervals(tokens) expected = [[2, 6], [3, 
6], [6, 9]] assert actual == expected - def test_call(self, transform, tokens, images, image_num_tokens): + def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens): + sample = {"tokens": tokens, "images": images} dummy_kwargs = {"hello": 8} - actual = transform(tokens=tokens, images=images, **dummy_kwargs) + sample.update(dummy_kwargs) + actual = cross_attn_mask_transform(sample) expected = [ torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool) for _ in range(len(images)) diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index e71f5222e..5f67988b5 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Mapping, Protocol, Tuple +from typing import Any, Dict, List, Mapping, Protocol import torch @@ -27,8 +27,8 @@ class VisionCrossAttentionMask(Transform): participate in cross-attention with an image token will show True in the mask and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper (https://arxiv.org/pdf/2204.14198): - 1) Text tokens immediately following the image token up until the next image token - 2) Consecutive image tokens attend to subsequent text tokens + 1) Text tokens immediately following the image token up until the next image token + 2) Consecutive image tokens attend to subsequent text tokens :: @@ -45,9 +45,10 @@ class VisionCrossAttentionMask(Transform): - Resultant mask is of shape (text_seq_len, image_seq_len), where True indicates - that the token outputted from the image encoder attends to the token in the - text sequence. 
+ Resultant mask is constructed per image and is of shape (text_seq_len, image_seq_len), + where True indicates that the token outputted from the image encoder attends + to the token in the text sequence in cross-attention. A list of these masks + is returned with length equal to number of images in the sample. Args: tile_size (int): The size of the image tiles from the image transform @@ -62,18 +63,16 @@ def __init__(self, tile_size: int, patch_size: int, image_token_id: int): self.patches_per_tile = patch_grid_size**2 self.image_token_id = image_token_id - def _get_image_attention_intervals( - self, tokens: List[int] - ) -> List[Tuple[int, int]]: + def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int, int]]: """ - Returns a list of tuples of the form (start, end) where start is the index + Returns a list of lists of the form [start, end) where start is the index of the current image token and end is the index of the next image token, exclusive. Args: tokens (List[int]): List of token IDs in the text sequence Returns: - List[Tuple[int, int]]: List of tuples of the form [start, end) indicating + List[List[int, int]]: List of lists of the form [start, end) indicating range of positions in text sequence that should attend to the image Example: @@ -83,7 +82,7 @@ def _get_image_attention_intervals( >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1) >>> intervals = transform._get_image_attention_intervals(tokens) >>> print(intervals) - [(0, 7), (1, 7), (7, 12)] + [[0, 7], [1, 7], [7, 12]] """ end = len(tokens) vision_token_locations = [ @@ -116,35 +115,38 @@ def _get_image_attention_intervals( last_mask_end = vision_mask[1] return vision_masks - def __call__(self, *, tokens: List[int], images: List[torch.Tensor], **kwargs): + def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: """ Generates the vision cross-attention mask for the given sample based on the image token locations interleaved in the 
text sequence. Args: - tokens (List[int]): List of token IDs in the text sequence. Number of - image token IDs in the sequence must match the number of images. - images (List[torch.Tensor]): List of image Tensors post-tiling of shape - (n_tiles, c, h, w) each. - **kwargs (Dict[str, Any]): all other keys within the sample that will - not be altered by this transform. + sample (Dict[str, Any]): Sample dict containing the following keys: + - tokens (List[int]): List of token IDs in the text sequence. Number of + image token IDs in the sequence must match the number of images. + - images (List[torch.Tensor]): List of image Tensors post-tiling of shape + (n_tiles, c, h, w) each. Returns: Dict[str, Any]: updated sample with the following keys: - - encoder_mask (List[torch.Tensor]): masks of shape (text_seq_len, image_seq_len) + - encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len), + where length of list == number of images in sample - tokens (List[int]): original tokens - images (List[torch.Tensor]): original images """ - + tokens, images = sample["tokens"], sample["images"] # One sample can have multiple images - verify the number of image tokens # is the same n_img = len(images) intervals = self._get_image_attention_intervals(tokens) - assert len(intervals) == n_img + if len(intervals) != n_img: + raise RuntimeError( + f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})." + ) # Create mask for each individual image based on its number of tokens, - # which can vary based on number of tiles. The masks are padded and concatenated - # together in the batch collator + # which can vary based on number of tiles since they are not yet tile padded. 
+ # The masks are padded and concatenated together in the batch collator text_seq_len = len(tokens) masks = [] for image_num, interval in enumerate(intervals): @@ -160,5 +162,5 @@ def __call__(self, *, tokens: List[int], images: List[torch.Tensor], **kwargs): mask[start:end, :] = True masks.append(mask) - kwargs.update({"encoder_mask": masks, "tokens": tokens, "images": images}) - return kwargs + sample.update({"encoder_mask": masks}) + return sample From f4708376bc1793cd34cf9f035ef9649b6b01ee49 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 9 Jul 2024 13:46:09 -0700 Subject: [PATCH 6/9] fix typing --- torchtune/modules/transforms/_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index 5f67988b5..26362ec2b 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -63,7 +63,7 @@ def __init__(self, tile_size: int, patch_size: int, image_token_id: int): self.patches_per_tile = patch_grid_size**2 self.image_token_id = image_token_id - def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int, int]]: + def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]: """ Returns a list of lists of the form [start, end) where start is the index of the current image token and end is the index of the next image token, exclusive. 
@@ -72,7 +72,7 @@ def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int, in tokens (List[int]): List of token IDs in the text sequence Returns: - List[List[int, int]]: List of lists of the form [start, end) indicating + List[List[int]]: List of lists of the form [start, end) indicating range of positions in text sequence that should attend to the image Example: From 04d3f9872f55dcb073220e63d7aeec7dbc895536 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 9 Jul 2024 13:54:55 -0700 Subject: [PATCH 7/9] fix docs --- torchtune/modules/transforms/_transforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index 26362ec2b..1cf3b87fa 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -27,8 +27,9 @@ class VisionCrossAttentionMask(Transform): participate in cross-attention with an image token will show True in the mask and follow the interleaved structure laid out in Fig. 
7 of the Flamingo paper (https://arxiv.org/pdf/2204.14198): - 1) Text tokens immediately following the image token up until the next image token - 2) Consecutive image tokens attend to subsequent text tokens + + (1) Text tokens immediately following the image token up until the next image token + (2) Consecutive image tokens attend to subsequent text tokens :: From f3ca1c6ab4deb3e7813c4e8e9ba78418d94c2880 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 9 Jul 2024 13:58:34 -0700 Subject: [PATCH 8/9] fix lint --- torchtune/modules/transforms/_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index 1cf3b87fa..920b6d019 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -27,7 +27,7 @@ class VisionCrossAttentionMask(Transform): participate in cross-attention with an image token will show True in the mask and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper (https://arxiv.org/pdf/2204.14198): - + (1) Text tokens immediately following the image token up until the next image token (2) Consecutive image tokens attend to subsequent text tokens From bead59a56ab374bcd6355f5e40f09ca4db812708 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Tue, 9 Jul 2024 15:09:37 -0700 Subject: [PATCH 9/9] update typing --- torchtune/modules/transforms/_transforms.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py index 920b6d019..68142686f 100644 --- a/torchtune/modules/transforms/_transforms.py +++ b/torchtune/modules/transforms/_transforms.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Dict, List, Mapping, Protocol +from typing import Any, List, Mapping, Protocol import torch @@ -12,12 +12,10 @@ class Transform(Protocol): """ Loose interface for all data and model transforms. Transforms operate at the - sample level and perform operations on a sample dict which is contained in - kwargs. Any fields that will be processed are unfolded with explicit keyword-arguments, - then the updated dict is returned. + sample level and perform operations on a sample dict, returning the updated dict. """ - def __call__(self, **kwargs) -> Mapping[str, Any]: + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: pass @@ -116,20 +114,20 @@ def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]: last_mask_end = vision_mask[1] return vision_masks - def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: """ Generates the vision cross-attention mask for the given sample based on the image token locations interleaved in the text sequence. Args: - sample (Dict[str, Any]): Sample dict containing the following keys: + sample (Mapping[str, Any]): Sample dict containing the following keys: - tokens (List[int]): List of token IDs in the text sequence. Number of image token IDs in the sequence must match the number of images. - images (List[torch.Tensor]): List of image Tensors post-tiling of shape (n_tiles, c, h, w) each. Returns: - Dict[str, Any]: updated sample with the following keys: + Mapping[str, Any]: updated sample with the following keys: - encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len), where length of list == number of images in sample - tokens (List[int]): original tokens