Commit bbc48e0

Vision cross attention mask transform (#1141)
1 parent 37636a8 commit bbc48e0

4 files changed (+254 −3 lines)


docs/source/api_ref_modules.rst

Lines changed: 1 addition & 0 deletions
@@ -95,3 +95,4 @@ Functions used for preprocessing images.
 transforms.resize_with_pad
 transforms.tile_crop
 transforms.find_supported_resolutions
+transforms.VisionCrossAttentionMask
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from torchtune.modules.transforms import VisionCrossAttentionMask
+
+
+IMAGE_TOKEN_ID = 1
+
+
+class TestVisionCrossAttentionMask:
+    @pytest.fixture
+    def num_tiles(self):
+        return 2
+
+    @pytest.fixture
+    def tile_size(self):
+        return 4
+
+    @pytest.fixture
+    def patch_size(self):
+        return 2
+
+    @pytest.fixture
+    def image_num_tokens(self, num_tiles, tile_size, patch_size):
+        return ((tile_size // patch_size) ** 2 + 1) * num_tiles
+
+    @pytest.fixture
+    def tokens(self):
+        # This tests image tokens not at start, consecutive images, and image
+        # with text until end.
+        # text = 2, image = 1
+        return [2, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 2, IMAGE_TOKEN_ID, 2, 2]
+
+    @pytest.fixture
+    def images(self, num_tiles, tokens):
+        n_img = len([token_id for token_id in tokens if token_id == IMAGE_TOKEN_ID])
+        n_channels = 3
+        tile_size = 2
+        return [
+            torch.ones(num_tiles, n_channels, tile_size, tile_size)
+            for _ in range(n_img)
+        ]
+
+    @pytest.fixture
+    def cross_attn_mask_transform(self, tile_size, patch_size):
+        # patches per tile = 4
+        return VisionCrossAttentionMask(
+            tile_size=tile_size,
+            patch_size=patch_size,
+            image_token_id=IMAGE_TOKEN_ID,
+        )
+
+    def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
+        actual = cross_attn_mask_transform._get_image_attention_intervals(tokens)
+        expected = [[2, 6], [3, 6], [6, 9]]
+        assert actual == expected
+
+    def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens):
+        sample = {"tokens": tokens, "images": images}
+        dummy_kwargs = {"hello": 8}
+        sample.update(dummy_kwargs)
+        actual = cross_attn_mask_transform(sample)
+        expected = [
+            torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
+            for _ in range(len(images))
+        ]
+        expected[0][2:6, :] = True
+        expected[1][3:6, :] = True
+        expected[2][6:9, :] = True
+        for i in range(len(images)):
+            torch.testing.assert_close(actual["encoder_mask"][i], expected[i])
+            torch.testing.assert_close(actual["images"][i], images[i])
+
+        assert actual["tokens"] == tokens
+        assert actual["hello"] == dummy_kwargs["hello"]

torchtune/modules/transforms/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -4,16 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .vision_utils.get_canvas_best_fit import (  # noqa
+from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask
+from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (  # noqa
     find_supported_resolutions,
     get_canvas_best_fit,
 )
-from .vision_utils.resize_with_pad import resize_with_pad  # noqa
-from .vision_utils.tile_crop import tile_crop  # noqa
+from torchtune.modules.transforms.vision_utils.resize_with_pad import (  # noqa
+    resize_with_pad,
+)
+from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop  # noqa
 
 __all__ = [
+    "Transform",
     "get_canvas_best_fit",
     "resize_with_pad",
     "tile_crop",
     "find_supported_resolutions",
+    "VisionCrossAttentionMask",
 ]
torchtune/modules/transforms/_transforms.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Mapping, Protocol
+
+import torch
+
+
+class Transform(Protocol):
+    """
+    Loose interface for all data and model transforms. Transforms operate at the
+    sample level and perform operations on a sample dict, returning the updated dict.
+    """
+
+    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        pass
+
+
+class VisionCrossAttentionMask(Transform):
+    """
+    Computes the cross-attention mask for text + image inputs. Text tokens that
+    participate in cross-attention with an image token will show True in the mask
+    and follow the interleaved structure laid out in Fig. 7 of the Flamingo paper
+    (https://arxiv.org/pdf/2204.14198):
+
+    (1) Text tokens immediately following the image token up until the next image token
+    (2) Consecutive image tokens attend to subsequent text tokens
+
+    ::
+
+              ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
+        img1  │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
+              └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
+              ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
+        img2  │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │   │ │   │ │   │ │   │ │   │
+              └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
+              ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
+        img3  │   │ │   │ │   │ │   │ │   │ │   │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
+              └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
+              <img1> <img2>These are two dogs. <img3> This is a cat.
+
+
+
+    Resultant mask is constructed per image and is of shape (text_seq_len, image_seq_len),
+    where True indicates that the token outputted from the image encoder attends
+    to the token in the text sequence in cross-attention. A list of these masks
+    is returned with length equal to the number of images in the sample.
+
+    Args:
+        tile_size (int): The size of the image tiles from the image transform
+        patch_size (int): The size of each patch. Used to divide the tiles into patches.
+            E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
+            with shape (40, 40) each.
+        image_token_id (int): Token ID of the image special token.
+    """
+
+    def __init__(self, tile_size: int, patch_size: int, image_token_id: int):
+        patch_grid_size = tile_size // patch_size
+        self.patches_per_tile = patch_grid_size**2
+        self.image_token_id = image_token_id
+
+    def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
+        """
+        Returns a list of lists of the form [start, end) where start is the index
+        of the current image token and end is the index of the next image token, exclusive.
+
+        Args:
+            tokens (List[int]): List of token IDs in the text sequence
+
+        Returns:
+            List[List[int]]: List of lists of the form [start, end) indicating
+                range of positions in text sequence that should attend to the image
+
+        Example:
+            >>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
+            >>> image_token_id = 1
+            >>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415]
+            >>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
+            >>> intervals = transform._get_image_attention_intervals(tokens)
+            >>> print(intervals)
+            [[0, 7], [1, 7], [7, 12]]
+        """
+        end = len(tokens)
+        vision_token_locations = [
+            i for i, token in enumerate(tokens) if token == self.image_token_id
+        ]
+        # Return empty list if there are no images
+        if len(vision_token_locations) == 0:
+            return []
+        # If there is only one image, it will attend to subsequent text until end
+        if len(vision_token_locations) == 1:
+            return [[vision_token_locations[0], end]]
+
+        # Construct intervals from previous image token to next image token
+        vision_masks = [
+            [tok_idx_prev, tok_idx_next]
+            # Offset by one to get consecutive indices
+            for tok_idx_prev, tok_idx_next in zip(
+                vision_token_locations[:-1], vision_token_locations[1:]
+            )
+        ]
+        # Last image will attend to subsequent text until end
+        vision_masks.append([vision_token_locations[-1], end])
+
+        # If there are consecutive vision tokens, they should all attend to the
+        # same subsequent text
+        last_mask_end = vision_masks[-1][1]
+        for vision_mask in vision_masks[::-1]:
+            if vision_mask[0] == vision_mask[1] - 1:
+                vision_mask[1] = last_mask_end
+            last_mask_end = vision_mask[1]
+        return vision_masks
+
+    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        """
+        Generates the vision cross-attention mask for the given sample based on
+        the image token locations interleaved in the text sequence.
+
+        Args:
+            sample (Mapping[str, Any]): Sample dict containing the following keys:
+                - tokens (List[int]): List of token IDs in the text sequence. Number of
+                  image token IDs in the sequence must match the number of images.
+                - images (List[torch.Tensor]): List of image Tensors post-tiling of shape
+                  (n_tiles, c, h, w) each.
+
+        Returns:
+            Mapping[str, Any]: updated sample with the following keys:
+                - encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len),
+                  where length of list == number of images in sample
+                - tokens (List[int]): original tokens
+                - images (List[torch.Tensor]): original images
+        """
+        tokens, images = sample["tokens"], sample["images"]
+        # One sample can have multiple images - verify the number of image tokens
+        # is the same
+        n_img = len(images)
+        intervals = self._get_image_attention_intervals(tokens)
+        if len(intervals) != n_img:
+            raise RuntimeError(
+                f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})."
+            )
+
+        # Create mask for each individual image based on its number of tokens,
+        # which can vary based on number of tiles since they are not yet tile padded.
+        # The masks are padded and concatenated together in the batch collator
+        text_seq_len = len(tokens)
+        masks = []
+        for image_num, interval in enumerate(intervals):
+            # Identify what part of text sequence should be attended
+            start, end = interval
+            # Compute this image's number of tokens based on num tiles, patches per tile
+            n_tiles = images[image_num].shape[0]
+            image_seq_len = n_tiles * (self.patches_per_tile + 1)  # +1 for CLS token
+            # Mask will be block of 1s at the corresponding interval in the text.
+            # It is not a causal block because all the image tokens correspond
+            # to a single image, so text tokens attend to all the image's tokens
+            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
+            mask[start:end, :] = True
+            masks.append(mask)
+
+        sample.update({"encoder_mask": masks})
+        return sample
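
For orientation, here is a minimal usage sketch of the new transform. It is not part of the diff: the sizes (tile_size=224, patch_size=14) and the dummy tensors are illustrative assumptions; the import path matches the updated __init__.py above.

import torch

from torchtune.modules.transforms import VisionCrossAttentionMask

# Illustrative sizes: (224 // 14) ** 2 = 256 patches per tile
transform = VisionCrossAttentionMask(tile_size=224, patch_size=14, image_token_id=1)

# Two images interleaved with text; token ID 1 marks an image, the other IDs are text
sample = {
    "tokens": [1, 5, 6, 7, 1, 8, 9],
    "images": [torch.randn(2, 3, 224, 224), torch.randn(2, 3, 224, 224)],  # (n_tiles, c, h, w)
}
out = transform(sample)

# One boolean mask per image, shape (text_seq_len, n_tiles * (patches_per_tile + 1))
print([m.shape for m in out["encoder_mask"]])  # [torch.Size([7, 514]), torch.Size([7, 514])]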
