Vision cross attention mask transform #1141

Merged 11 commits on Jul 10, 2024
1 change: 1 addition & 0 deletions docs/source/api_ref_modules.rst
@@ -95,3 +95,4 @@ Functions used for preprocessing images.
transforms.resize_with_pad
transforms.tile_crop
transforms.find_supported_resolutions
transforms.VisionCrossAttentionMask
80 changes: 80 additions & 0 deletions tests/torchtune/modules/transforms/test_transforms.py
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from torchtune.modules.transforms import VisionCrossAttentionMask


IMAGE_TOKEN_ID = 1


class TestVisionCrossAttentionMask:
@pytest.fixture
def num_tiles(self):
return 2

@pytest.fixture
def tile_size(self):
return 4

@pytest.fixture
def patch_size(self):
return 2

@pytest.fixture
def image_num_tokens(self, num_tiles, tile_size, patch_size):
return ((tile_size // patch_size) ** 2 + 1) * num_tiles

@pytest.fixture
def tokens(self):
# This tests image tokens not at start, consecutive images, and image
# with text until end.
# text = 2, image = 1
return [2, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 2, IMAGE_TOKEN_ID, 2, 2]

@pytest.fixture
def images(self, num_tiles, tokens):
n_img = len([token_id for token_id in tokens if token_id == IMAGE_TOKEN_ID])
n_channels = 3
        # Pixel dims of the dummy images don't affect the mask; only n_tiles is used
        tile_size = 2
return [
torch.ones(num_tiles, n_channels, tile_size, tile_size)
for _ in range(n_img)
]

@pytest.fixture
def cross_attn_mask_transform(self, tile_size, patch_size):
# patches per tile = 4
return VisionCrossAttentionMask(
tile_size=tile_size,
patch_size=patch_size,
image_token_id=IMAGE_TOKEN_ID,
)

def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
actual = cross_attn_mask_transform._get_image_attention_intervals(tokens)
expected = [[2, 6], [3, 6], [6, 9]]
assert actual == expected

def test_call(self, cross_attn_mask_transform, tokens, images, image_num_tokens):
sample = {"tokens": tokens, "images": images}
dummy_kwargs = {"hello": 8}
sample.update(dummy_kwargs)
actual = cross_attn_mask_transform(sample)
expected = [
torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
for _ in range(len(images))
]
expected[0][2:6, :] = True
expected[1][3:6, :] = True
expected[2][6:9, :] = True
for i in range(len(images)):
torch.testing.assert_close(actual["encoder_mask"][i], expected[i])
torch.testing.assert_close(actual["images"][i], images[i])

assert actual["tokens"] == tokens
assert actual["hello"] == dummy_kwargs["hello"]
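For reference, the fixture values above fully determine how many image-encoder tokens each dummy image contributes. A small arithmetic sketch of that calculation (illustration only, not part of the PR's test file):

# Each tile yields (tile_size // patch_size) ** 2 patch tokens plus one CLS token,
# and the fixtures use 2 tiles per image.
num_tiles, tile_size, patch_size = 2, 4, 2

patches_per_tile = (tile_size // patch_size) ** 2      # 4
image_num_tokens = (patches_per_tile + 1) * num_tiles  # (4 + 1) * 2 = 10

# Each expected mask in test_call is therefore (len(tokens), 10) == (9, 10),
# with True rows only at the intervals [[2, 6], [3, 6], [6, 9]].
assert image_num_tokens == 10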
11 changes: 8 additions & 3 deletions torchtune/modules/transforms/__init__.py
@@ -4,16 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

-from .vision_utils.get_canvas_best_fit import ( # noqa
+from torchtune.modules.transforms._transforms import Transform, VisionCrossAttentionMask
+from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import ( # noqa
     find_supported_resolutions,
     get_canvas_best_fit,
 )
-from .vision_utils.resize_with_pad import resize_with_pad # noqa
-from .vision_utils.tile_crop import tile_crop # noqa
+from torchtune.modules.transforms.vision_utils.resize_with_pad import ( # noqa
+    resize_with_pad,
+)
+from torchtune.modules.transforms.vision_utils.tile_crop import tile_crop # noqa

 __all__ = [
+    "Transform",
     "get_canvas_best_fit",
     "resize_with_pad",
     "tile_crop",
     "find_supported_resolutions",
+    "VisionCrossAttentionMask",
 ]
165 changes: 165 additions & 0 deletions torchtune/modules/transforms/_transforms.py
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Mapping, Protocol

import torch


class Transform(Protocol):
"""
Loose interface for all data and model transforms. Transforms operate at the
sample level and perform operations on a sample dict, returning the updated dict.
"""

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
pass


class VisionCrossAttentionMask(Transform):
"""
    Computes the cross-attention mask for text + image inputs. Text tokens that
    participate in cross-attention with an image token are marked True in the mask,
    following the interleaved structure laid out in Fig. 7 of the Flamingo paper
    (https://arxiv.org/pdf/2204.14198):

    (1) Text tokens immediately following the image token, up until the next image token
    (2) Consecutive image tokens attend to the same subsequent text tokens

::

┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img1 │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │
└───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img2 │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ │ │ │ │ │ │ │ │ │
└───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
img3 │ │ │ │ │ │ │ │ │ │ │ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │ │ ■ │
└───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘ └───┘
<img1> <img2>These are two dogs. <img3> This is a cat.

    The resulting mask is constructed per image and has shape (text_seq_len, image_seq_len),
    where an entry is True if the text token at that row participates in cross-attention
    with the image-encoder token at that column. A list of these masks is returned,
    with length equal to the number of images in the sample.

Args:
tile_size (int): The size of the image tiles from the image transform
patch_size (int): The size of each patch. Used to divide the tiles into patches.
            E.g. for patch_size = 40, a tile of shape (400, 400) will have a 10x10 grid of patches
            with shape (40, 40) each.
image_token_id (int): Token ID of the image special token.
"""

def __init__(self, tile_size: int, patch_size: int, image_token_id: int):
patch_grid_size = tile_size // patch_size
self.patches_per_tile = patch_grid_size**2
self.image_token_id = image_token_id

def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
"""
Returns a list of lists of the form [start, end) where start is the index
of the current image token and end is the index of the next image token, exclusive.

Args:
tokens (List[int]): List of token IDs in the text sequence

Returns:
List[List[int]]: List of lists of the form [start, end) indicating
range of positions in text sequence that should attend to the image

Example:
>>> text = "<img1><img2>These are two dogs. <img3>This is a cat."
>>> image_token_id = 1
>>> tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415]
>>> transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)
>>> intervals = transform._get_image_attention_intervals(tokens)
>>> print(intervals)
[[0, 7], [1, 7], [7, 12]]
"""
end = len(tokens)
vision_token_locations = [
i for i, token in enumerate(tokens) if token == self.image_token_id
]
# Return empty list if there are no images
if len(vision_token_locations) == 0:
return []
# If there is only one image, it will attend to subsequent text until end
if len(vision_token_locations) == 1:
return [[vision_token_locations[0], end]]

# Construct intervals from previous image token to next image token
vision_masks = [
[tok_idx_prev, tok_idx_next]
            # Pair each image token index with the index of the next image token
for tok_idx_prev, tok_idx_next in zip(
vision_token_locations[:-1], vision_token_locations[1:]
)
]
# Last image will attend to subsequent text until end
vision_masks.append([vision_token_locations[-1], end])

# If there are consecutive vision tokens, they should all attend to the
# same subsequent text
last_mask_end = vision_masks[-1][1]
for vision_mask in vision_masks[::-1]:
if vision_mask[0] == vision_mask[1] - 1:
vision_mask[1] = last_mask_end
last_mask_end = vision_mask[1]
return vision_masks

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Generates the vision cross-attention mask for the given sample based on
the image token locations interleaved in the text sequence.

Args:
sample (Mapping[str, Any]): Sample dict containing the following keys:
- tokens (List[int]): List of token IDs in the text sequence. Number of
image token IDs in the sequence must match the number of images.
- images (List[torch.Tensor]): List of image Tensors post-tiling of shape
(n_tiles, c, h, w) each.

Returns:
Mapping[str, Any]: updated sample with the following keys:
- encoder_mask (List[torch.Tensor]): list of masks with shape (text_seq_len, image_seq_len),
where length of list == number of images in sample
- tokens (List[int]): original tokens
- images (List[torch.Tensor]): original images
"""
tokens, images = sample["tokens"], sample["images"]
        # One sample can have multiple images - verify that the number of image
        # token intervals matches the number of images
n_img = len(images)
intervals = self._get_image_attention_intervals(tokens)
if len(intervals) != n_img:
raise RuntimeError(
f"The number of image tokens ({len(intervals)}) does not match the number of images ({n_img})."
)

# Create mask for each individual image based on its number of tokens,
# which can vary based on number of tiles since they are not yet tile padded.
# The masks are padded and concatenated together in the batch collator
text_seq_len = len(tokens)
masks = []
for image_num, interval in enumerate(intervals):
# Identify what part of text sequence should be attended
start, end = interval
# Compute this image's number of tokens based on num tiles, patches per tile
n_tiles = images[image_num].shape[0]
image_seq_len = n_tiles * (self.patches_per_tile + 1) # +1 for CLS token
# Mask will be block of 1s at the corresponding interval in the text.
# It is not a causal block because all the image tokens correspond
# to a single image, so text tokens attend to all the image's tokens
mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
mask[start:end, :] = True
masks.append(mask)

sample.update({"encoder_mask": masks})
return sample
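To make the intended usage concrete, here is a minimal end-to-end sketch of the transform, reusing the token IDs from the docstring example above. The image tensors are dummies and the choice of 2 tiles per image is an assumption; only the first dimension (n_tiles) of each image affects the mask.

import torch

from torchtune.modules.transforms import VisionCrossAttentionMask

# Same settings as the docstring example: 400px tiles, 40px patches, image token id 1
transform = VisionCrossAttentionMask(tile_size=400, patch_size=40, image_token_id=1)

# "<img1><img2>These are two dogs. <img3>This is a cat."
tokens = [1, 1, 9673, 527, 1403, 12875, 13, 1, 1115, 374, 264, 8415]
# Dummy post-tiling images of shape (n_tiles, c, h, w); 2 tiles each is assumed
images = [torch.ones(2, 3, 400, 400) for _ in range(3)]

out = transform({"tokens": tokens, "images": images})
# patches_per_tile = (400 // 40) ** 2 = 100, so each image contributes
# 2 * (100 + 1) = 202 encoder tokens; the text sequence has 12 tokens.
for mask in out["encoder_mask"]:
    print(mask.shape)  # torch.Size([12, 202])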