Vision cross attention mask transform #1141
@@ -0,0 +1,17 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.modules.transforms._transforms import (
    Compose,
    TokenizeMessages,
    Transform,
)

__all__ = [
    "Transform",
    "Compose",
    "TokenizeMessages",
]
```
@@ -0,0 +1,142 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Mapping, Protocol

import torch

from torchtune.data import Message
from torchtune.modules.tokenizers import ModelTokenizer
```
```python
class Transform(Protocol):
    """
    Loose interface for all data and model transforms. Transforms operate at the
    sample level and perform operations on a sample dict, which is contained in
    kwargs. Any fields that will be processed are unfolded as explicit keyword
    arguments, then the updated dict is returned.
    """

    def __call__(self, **kwargs) -> Mapping[str, Any]:
        pass
```
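As a quick illustration of the protocol (not part of the diff; the ``AddEOS`` name and the ``tokens`` field are hypothetical), a custom transform just unfolds the fields it needs and folds the updated values back into the dict:

```python
class AddEOS(Transform):
    """Hypothetical transform that appends an EOS token id to the tokens field."""

    def __init__(self, eos_id: int) -> None:
        self.eos_id = eos_id

    def __call__(self, *, tokens, **kwargs):
        # Unfold the field we process, update it, then fold it back into the dict
        kwargs.update({"tokens": tokens + [self.eos_id]})
        return kwargs
```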
```python
class Compose(Transform):
    """
    Compose multiple transforms together, inspired by torchvision's ``Compose`` API.

    Args:
        transforms (List[Transform]): List of transforms to compose together in sequential order.
    """

    def __init__(self, transforms: List[Transform]) -> None:
        self.transforms = transforms

    def __call__(self, **kwargs) -> Mapping[str, Any]:
        for transform in self.transforms:
            kwargs = transform(**kwargs)
        return kwargs
```
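A minimal sketch of the sequencing behavior, reusing the hypothetical ``AddEOS`` transform from above: each transform receives the dict produced by the previous one.

```python
pipeline = Compose(transforms=[AddEOS(eos_id=2), AddEOS(eos_id=0)])
out = pipeline(tokens=[1, 5, 7])
print(out["tokens"])  # [1, 5, 7, 2, 0]
```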
```python
class TokenizeMessages(Transform):
    """
    Apply the ``tokenize_messages`` method from a given
    :class:`~torchtune.modules.tokenizers.ModelTokenizer` on the ``messages`` field of the sample.

    Args:
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements
            the ``tokenize_messages`` method.
        max_seq_len (int): Maximum sequence length, passed through to ``tokenize_messages``.
    """

    def __init__(self, tokenizer: ModelTokenizer, max_seq_len: int) -> None:
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, *, messages: List[Message], **kwargs) -> Mapping[str, Any]:
        tokens, mask = self.tokenizer.tokenize_messages(
            messages, max_seq_len=self.max_seq_len
        )
        kwargs.update({"tokens": tokens, "mask": mask})
        return kwargs
```
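A hedged usage sketch, assuming ``tokenizer`` is any ``ModelTokenizer`` (its ``eos_id`` attribute is an assumption here, and the messages are made up):

```python
from torchtune.data import Message

messages = [
    Message(role="user", content="What is in this image?"),
    Message(role="assistant", content="A cat."),
]
pipeline = Compose(
    transforms=[
        TokenizeMessages(tokenizer=tokenizer, max_seq_len=512),  # adds "tokens" and "mask"
        AddEOS(eos_id=tokenizer.eos_id),  # hypothetical transform from above
    ]
)
sample = pipeline(messages=messages)
```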
```python
class CrossAttentionMask(Transform):
    """
    Computes the cross-attention mask for text + image inputs. Text tokens that
    participate in cross-attention with an image token will show True in the mask
    and follow these rules:
    1) Text tokens immediately following the image token, up until the next image token
    2) Consecutive image tokens attend to all subsequent text tokens

    The resultant mask is of shape (text_seq_len, image_seq_len), where True indicates
    that the token outputted from the image encoder attends to the token in the
    text sequence.

    Args:
        num_patches (int): Number of patches per image, excluding class token.
        image_token_id (int): Token ID of the image special token.
    """

    def __init__(self, num_patches: int, image_token_id: int):
        self.num_patches = num_patches
        self.image_token_id = image_token_id
```

Review comment: If this CrossAttention is specific to text + images, should we indicate it in the name? Something like MultimodalCrossAttentionMask or VisionTextCrossAttentionMask?

Reply: Yeah, that's a good point. Will rename.

Review comment: Can we make it more visual? Something like this: https://github.com/huggingface/transformers/blob/60bb571e993b7d73257fb64044726b569fef9403/src/transformers/models/llava_next/modeling_llava_next.py#L446 Or a link to the paper + the page where they have an image for it.

Review comment: I wonder if we should add a link to modules/VisionTransformer for an in-depth explanation of what num_patches means. For better clarity, would it make sense to rename it num_patches_per_tile, since later we multiply it by n_tiles? If we say "number of patches per image", it may be confusing, because an image can have a variable number of patches.

Reply: So tile != image. Image is a set of tiles.

Review comment: Ah sorry, probably my shallow understanding of patches/images/tiles. What I intended was num_patches per tile. If it makes sense I'd like to keep the name consistent with your vision transformer (either patch_grid_size or patch_size maybe?), whichever parameter you use to compute num_patches. I also assumed that at this point tiles are padded to the max across all the images. Is this incorrect? Where does the padding happen?

Reply: patches_per_tile is a fixed size, and it's calculated as (tile_size // patch_size)**2. What I did in VisionTransformer was to ask the user to pass tile_size and patch_size, and I calculated it for them. The VisionTransformer has a helper function that saves this value: https://github.com/felipemello1/torchtune/blob/f683812626ad4559464840112ddce516487bea5c/torchtune/modules/vision_transformer.py#L249 Maybe get it from the model, or ask for tile_size and patch_size, to avoid user confusion?
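Picking up the reviewer's request for something more visual, here is a rough sketch (not from the PR; all values made up) of the mask the rules above produce:

```python
# Illustration only. Suppose image_token_id = 0, num_patches = 3, and a sample
# with two single-tile images: tokens = [0, 5, 6, 0, 7].
# Each image yields 1 * 3 + 1 = 4 encoder tokens (+1 for the class token),
# so the mask has shape (5, 8). Text between the two image tokens attends only
# to the first image; the second image token and the text after it attend only
# to the second image:
#
#                    image 0       image 1
#   row 0 (img 0):   T T T T       F F F F
#   row 1 (tok 5):   T T T T       F F F F
#   row 2 (tok 6):   T T T T       F F F F
#   row 3 (img 1):   F F F F       T T T T
#   row 4 (tok 7):   F F F F       T T T T
```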
```python
    def _get_image_attention_intervals(
        self, tokens: List[int]
    ) -> List[List[int]]:
        """
        Returns a list of [start, end] lists, where start is the index of the
        current image token and end is the index of the next image token, exclusive.
        If the image token attends until the end of the sequence, end will be -1.
        """
        vision_token_locations = [
            i for i, token in enumerate(tokens) if token == self.image_token_id
        ]
        # Return empty list if there are no images
        if len(vision_token_locations) == 0:
            return []
        # If there is only one image, it will attend to subsequent text until end
        if len(vision_token_locations) == 1:
            return [[vision_token_locations[0], -1]]

        # Each image attends until the next image token appears
        vision_masks = [
            [tok1, tok2]
            for tok1, tok2 in zip(
                vision_token_locations[:-1], vision_token_locations[1:]
            )
        ]
        # Last image will attend to subsequent text until end
        vision_masks.append([vision_token_locations[-1], -1])

        # If there are consecutive vision tokens, they should all attend to the
        # same subsequent text
        last_mask_end = vision_masks[-1][1]
        for vision_mask in vision_masks[::-1]:
            if vision_mask[0] == vision_mask[1] - 1:
                vision_mask[1] = last_mask_end
            last_mask_end = vision_mask[1]
        return vision_masks
```

Review comment: nit: should we add Args:, Returns:, Examples?
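A worked example of the interval logic may help (values made up, using image_token_id = 2); note how the consecutive image tokens at indices 4 and 5 end up sharing the same end index:

```python
transform = CrossAttentionMask(num_patches=9, image_token_id=2)
# Image tokens sit at indices 1, 4, and 5; tokens 4 and 5 are consecutive.
intervals = transform._get_image_attention_intervals([1, 2, 3, 4, 2, 2, 5])
assert intervals == [[1, 4], [4, -1], [5, -1]]
```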
```python
    def __call__(
        self, *, tokens: List[int], images: torch.Tensor, **kwargs
    ) -> Mapping[str, Any]:
        # We are still at sample level pre-collating.
        # images shape: (n_img, n_tiles, n_channels, tile_height, tile_width);
        # only the first two dimensions are used here
        n_img, n_tiles, _, _, _ = images.shape
        text_seq_len = len(tokens)
        # +1 for the class token emitted by the image encoder
        single_image_seq_len = n_tiles * self.num_patches + 1
        image_seq_len = single_image_seq_len * n_img
        intervals = self._get_image_attention_intervals(tokens)
        assert len(intervals) == n_img

        mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
        for image_num, interval in enumerate(intervals):
            start, end = interval
            end = text_seq_len if end == -1 else end
            start_col = image_num * single_image_seq_len
            mask[start:end, start_col : start_col + single_image_seq_len] = True

        kwargs.update({"encoder_mask": mask, "tokens": tokens, "images": images})
        return kwargs
```

Review comment: nit: Maybe add a comment with the other dimensions, so we know what they are, but keep the "_", so we know they are not used? You said "# We are still at sample level pre-collating".

Reply: yeah please add type and shape info for arguments to

Review comment: nit: maybe add a comment explaining that +1 is for CLS, if that's the case.

Review comment: nit: split the mask slicing line differently if the linter allows it, as written it is confusing.

Review comment: Why do we need to also update with tokens and images? Isn't this a no-op for those args?

Reply: Since they are explicit keyword args, they get unfolded from kwargs and you have to add them back in.
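And a hedged end-to-end sketch of the transform itself (shapes, tile size, and token ids are illustrative, not from the PR):

```python
import torch

transform = CrossAttentionMask(num_patches=9, image_token_id=128256)
sample = transform(
    tokens=[128256, 1, 2, 3],               # one image token followed by text
    images=torch.randn(1, 4, 3, 224, 224),  # (n_img, n_tiles, c, h, w)
)
# single_image_seq_len = 4 * 9 + 1 = 37, so the mask has shape (4, 37)
print(sample["encoder_mask"].shape)  # torch.Size([4, 37])
```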
Review comment: Since the torchvision Compose has a different behavior, I wonder if it makes sense to change Compose to something else, so users don't get confused with tv.Compose. Maybe "ComposeTransforms"?

Reply: How about Pipeline?

Review comment: Just for my own understanding, the main difference with torchvision Compose is that we support multiple inputs and multiple outputs here? Can we not just use torchvision Compose with a single dict?

Reply: I tried naming something as Pipeline, and Kartikay said it would confuse people, because it is also used by other libraries :P. I guess sklearn?

Reply: @ebsmothers our Compose needs to have a slightly different forward signature to unfold dictionary inputs, unlike torchvision's. But to avoid confusion, I agree we should name it something else. Just haven't figured out what yet.
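For reference, a side-by-side sketch of the difference the reply alludes to (paraphrased, not quoted from the thread): torchvision's ``Compose.__call__`` threads a single positional input, while the ``Compose`` in this PR unfolds and re-folds a keyword dict at each step.

```python
# torchvision.transforms.Compose: one positional input threaded through
def __call__(self, img):
    for t in self.transforms:
        img = t(img)
    return img

# Compose in this PR: keyword dict unfolded and re-folded at each step
def __call__(self, **kwargs):
    for transform in self.transforms:
        kwargs = transform(**kwargs)
    return kwargs
```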