Vision cross attention mask transform #1141
Merged

Changes from all commits (11 commits, all by RdoubleA):
- d90e5a9 add basic transforms
- 5f7e8aa add xattn mask transform
- 9a280f2 add transforms test
- c341f6b Merge branch 'main' into mm_transforms
- 7f22d58 only do cross attention mask
- 92faf71 Merge branch 'main' into mm_transforms
- edfec5a address comments
- f470837 fix typing
- 04d3f98 fix docs
- f3ca1c6 fix lint
- bead59a update typing
New test file (`@@ -0,0 +1,80 @@`):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from torchtune.modules.transforms import VisionCrossAttentionMask

IMAGE_TOKEN_ID = 1


class TestVisionCrossAttentionMask:
    @pytest.fixture
    def num_tiles(self):
        return 2

    @pytest.fixture
    def tile_size(self):
        return 4

    @pytest.fixture
    def patch_size(self):
        return 2

    @pytest.fixture
    def image_num_tokens(self, num_tiles, tile_size, patch_size):
        return ((tile_size // patch_size) ** 2 + 1) * num_tiles

    @pytest.fixture
    def tokens(self):
        # This tests image tokens not at start, consecutive images, and image
        # with text until end.
        # text = 2, image = 1
        return [2, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 2, IMAGE_TOKEN_ID, 2, 2]

    @pytest.fixture
    def images(self, num_tiles, tokens):
        n_img = len([token_id for token_id in tokens if token_id == IMAGE_TOKEN_ID])
        n_channels = 3
        tile_size = 2
        return [
            torch.ones(num_tiles, n_channels, tile_size, tile_size)
            for _ in range(n_img)
        ]

    @pytest.fixture
    def cross_attn_mask_transform(self, tile_size, patch_size):
        # patches per tile = 4
        return VisionCrossAttentionMask(
            tile_size=tile_size,
            patch_size=patch_size,
            image_token_id=IMAGE_TOKEN_ID,
        )

    def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
        actual = cross_attn_mask_transform._get_image_attention_intervals(tokens)
        expected = [[2, 6], [3, 6], [6, 9]]
        assert actual == expected

    def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens):
        sample = {"tokens": tokens, "images": images}
        dummy_kwargs = {"hello": 8}
        sample.update(dummy_kwargs)
        actual = cross_attn_mask_transform(sample)
        expected = [
            torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
            for _ in range(len(images))
        ]
        expected[0][2:6, :] = True
        expected[1][3:6, :] = True
        expected[2][6:9, :] = True
        for i in range(len(images)):
            torch.testing.assert_close(actual["encoder_mask"][i], expected[i])
            torch.testing.assert_close(actual["images"][i], images[i])

        assert actual["tokens"] == tokens
        assert actual["hello"] == dummy_kwargs["hello"]
```
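With these fixture values, each image contributes ((4 // 2) ** 2 + 1) * 2 = 10 tokens (4 patches per tile plus one CLS token, across 2 tiles), so every expected mask has shape (9, 10) for the 9-token text sequence, with True blocks at the intervals [2, 6), [3, 6), and [6, 9) asserted above.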
New module file (`@@ -0,0 +1,165 @@`), defining the `Transform` protocol and `VisionCrossAttentionMask`:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Mapping, Protocol

import torch


class Transform(Protocol):
    """
    Loose interface for all data and model transforms. Transforms operate at the
    sample level and perform operations on a sample dict, returning the updated dict.
    """

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        pass


class VisionCrossAttentionMask(Transform):
    """
    Computes the cross-attention mask for text + image inputs. Text tokens that
    participate in cross-attention with an image token will show True in the mask
    and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper
    (https://arxiv.org/pdf/2204.14198):

    (1) Text tokens immediately following the image token, up until the next image token
    (2) Consecutive image tokens attend to subsequent text tokens

    ::

             ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
        img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
             └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
             ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
        img2 │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
             └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
             ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
        img3 │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
             └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
            <img1> <img2>These are two dogs. <img3> This is a cat.

    The resultant mask is constructed per image and is of shape (text_seq_len, image_seq_len),
    where True indicates that the token outputted from the image encoder attends
    to the token in the text sequence in cross-attention. A list of these masks
    is returned, with length equal to the number of images in the sample.

    Args:
        tile_size (int): The size of the image tiles from the image transform
        patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for patch_size = 40, a tile of shape (400, 400) will have a 10x10 grid
            of patches with shape (40, 40) each.
        image_token_id (int): Token ID of the image special token.
    """

    def __init__(self, tile_size: int, patch_size: int, image_token_id: int):
        patch_grid_size = tile_size // patch_size
        self.patches_per_tile = patch_grid_size**2
        self.image_token_id = image_token_id

    def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
        """
        Returns a list of lists of the form [start, end) where start is the index
        of the current image token and end is the index of the next image token, exclusive.

        Args:
            tokens (List[int]): List of token IDs in the text sequence

        Returns:
            List[List[int]]: List of lists of the form [start, end) indicating
                the range of positions in the text sequence that should attend to the image

        Example:
            >>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
            >>> image_token_id = 1
            >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415]
            >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
            >>> intervals = transform._get_image_attention_intervals(tokens)
            >>> print(intervals)
            [[0, 7], [1, 7], [7, 12]]
        """
        end = len(tokens)
        vision_token_locations = [
            i for i, token in enumerate(tokens) if token == self.image_token_id
        ]
        # Return empty list if there are no images
        if len(vision_token_locations) == 0:
            return []
        # If there is only one image, it will attend to subsequent text until end
        if len(vision_token_locations) == 1:
            return [[vision_token_locations[0], end]]

        # Construct intervals from previous image token to next image token
        vision_masks = [
            [tok_idx_prev, tok_idx_next]
            # Offset by one to get consecutive indices
            for tok_idx_prev, tok_idx_next in zip(
                vision_token_locations[:-1], vision_token_locations[1:]
            )
        ]
        # Last image will attend to subsequent text until end
        vision_masks.append([vision_token_locations[-1], end])

        # If there are consecutive vision tokens, they should all attend to the
        # same subsequent text
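        # E.g. in the docstring example above, the raw intervals [[0, 1], [1, 7], [7, 12]]
        # become [[0, 7], [1, 7], [7, 12]]: walking backwards, the image token at
        # position 0 is directly adjacent to the one at position 1, so it inherits
        # its neighbor's end index of 7.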
        last_mask_end = vision_masks[-1][1]
        for vision_mask in vision_masks[::-1]:
            if vision_mask[0] == vision_mask[1] - 1:
                vision_mask[1] = last_mask_end
            last_mask_end = vision_mask[1]
        return vision_masks

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Generates the vision cross-attention mask for the given sample based on
        the image token locations interleaved in the text sequence.

        Args:
            sample (Mapping[str, Any]): Sample dict containing the following keys:
                - tokens (List[int]): List of token IDs in the text sequence. Number of
                    image token IDs in the sequence must match the number of images.
                - images (List[torch.Tensor]): List of image Tensors post-tiling of shape
                    (n_tiles, c, h, w) each.

        Returns:
            Mapping[str, Any]: updated sample with the following keys:
                - encoder_mask (List[torch.Tensor]): list of masks with shape
                    (text_seq_len, image_seq_len), where the length of the list equals
                    the number of images in the sample
                - tokens (List[int]): original tokens
                - images (List[torch.Tensor]): original images
        """
        tokens, images = sample["tokens"], sample["images"]
        # One sample can have multiple images - verify the number of image tokens
        # matches the number of images
        n_img = len(images)
        intervals = self._get_image_attention_intervals(tokens)
        if len(intervals) != n_img:
            raise RuntimeError(
                f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})."
            )

        # Create a mask for each individual image based on its number of tokens,
        # which can vary with the number of tiles since they are not yet tile padded.
        # The masks are padded and concatenated together in the batch collator
        text_seq_len = len(tokens)
        masks = []
        for image_num, interval in enumerate(intervals):
            # Identify what part of the text sequence should be attended to
            start, end = interval
            # Compute this image's number of tokens based on num tiles, patches per tile
            n_tiles = images[image_num].shape[0]
            image_seq_len = n_tiles * (self.patches_per_tile + 1)  # +1 for CLS token

            # Mask will be a block of 1s at the corresponding interval in the text.
            # It is not a causal block because all the image tokens correspond
            # to a single image, so text tokens attend to all of the image's tokens
            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
            mask[start:end, :] = True
            masks.append(mask)

        sample.update({"encoder_mask": masks})
        return sample
```
There was an error while loading. Please reload this page.