Skip to content

Commit

Permalink
wrap up NaViT
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 25, 2023
1 parent 32974c3 commit 6e2393d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 3 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,24 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution abov

```

Or, if you would rather have the framework automatically group the images into variable-length sequences that do not exceed a certain maximum length:

```python
images = [
torch.randn(3, 256, 256),
torch.randn(3, 128, 128),
torch.randn(3, 128, 256),
torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]

preds = v(
images,
group_images = True,
group_max_seq_len = 64
) # (5, 1000)
```

## Distillation

<img src="./images/distill.png" width="300px"></img>
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.8',
version = '1.2.9',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
62 changes: 60 additions & 2 deletions vit_pytorch/na_vit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import List
from typing import List, Union

import torch
import torch.nn.functional as F
Expand All @@ -17,12 +17,58 @@ def exists(val):
def default(val, d):
    """Return `val` if `exists(val)` is truthy, otherwise the fallback `d`."""
    if exists(val):
        return val
    return d

def always(val):
    """Build a callable that ignores its arguments and returns `val`.

    The returned function accepts any positional and keyword arguments, so it
    can stand in for a callback regardless of how the caller invokes it.
    """
    def constant(*args, **kwargs):
        # arguments are deliberately discarded; only the captured value matters
        return val

    return constant

def pair(t):
    """Return `t` unchanged when it is a tuple, otherwise the 2-tuple `(t, t)`."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def divisible_by(numer, denom):
    """Return whether `numer % denom` equals zero."""
    remainder = numer % denom
    return remainder == 0

# auto grouping images

def group_images_by_max_seq_len(
    images: List[Tensor],
    patch_size: int,
    calc_token_dropout = None,
    max_seq_len = 2048
) -> List[List[Tensor]]:
    """Greedily pack images, in the given order, into groups whose summed
    sequence lengths do not exceed `max_seq_len`.

    An image's sequence length is the product of its last two dimensions
    (presumably height and width), each floor-divided by `patch_size`. That
    product is scaled by `1 - dropout` and truncated to an int.

    Args:
        images: tensors to group; each element must be a `Tensor`.
        patch_size: side length used to turn the trailing two dims into a patch count.
        calc_token_dropout: `None` (treated as a constant 0.), a float / int
            (treated as a constant rate), or a callable invoked with the
            image's last two dimensions that returns the dropout rate.
        max_seq_len: upper bound on the summed sequence length of one group.

    Returns:
        A list of groups, each a list of the original image tensors.

    Raises:
        AssertionError: if an element is not a `Tensor`, or if a single image's
            sequence length alone exceeds `max_seq_len`.
    """

    # normalize the dropout argument into a callable

    dropout_fn = default(calc_token_dropout, always(0.))

    if isinstance(dropout_fn, (float, int)):
        dropout_fn = always(dropout_fn)

    packed = []
    current = []
    current_len = 0

    for image in images:
        assert isinstance(image, Tensor)

        image_dims = image.shape[-2:]
        dim_h, dim_w = image_dims

        # number of whole patches along each of the trailing two dims

        num_patches = (dim_h // patch_size) * (dim_w // patch_size)

        # sequence length remaining after token dropout, truncated

        image_seq_len = int(num_patches * (1 - dropout_fn(*image_dims)))

        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'

        # close out the current group if this image would overflow it

        would_overflow = (current_len + image_seq_len) > max_seq_len

        if would_overflow:
            packed.append(current)
            current = []
            current_len = 0

        current.append(image)
        current_len += image_seq_len

    # flush the trailing, partially filled group

    if current:
        packed.append(current)

    return packed

# normalization
# they use layernorm without bias, something that pytorch does not offer

Expand Down Expand Up @@ -199,13 +245,25 @@ def device(self):

def forward(
self,
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
):
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)

arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)

# auto pack if specified

if group_images:
batched_images = group_images_by_max_seq_len(
batched_images,
patch_size = self.patch_size,
calc_token_dropout = self.calc_token_dropout,
max_seq_len = group_max_seq_len
)

# process images into variable lengthed sequences with attention mask

num_images = []
Expand Down

0 comments on commit 6e2393d

Please sign in to comment.