+from collections import deque
 from collections.abc import Iterable
 from functools import partial
-from itertools import islice, cycle, product
+from itertools import islice, cycle
 
 import torch
 from torch import nn, einsum
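The new `deque` import backs the cache introduced below: each layer keeps a fixed-length window over the `(x_top, x_left)` chunk pairs of the last `image_size` positions, so the neighbor one full image row back is recovered in O(1) per generated token. A minimal sketch of that invariant, with integers standing in for the chunk pairs:

    from collections import deque

    image_size = 4
    q = deque(range(image_size))  # window over the last image_size positions
    q.append(99)                  # newest position enters on the right
    top = q.popleft()             # position one full row above leaves on the left
    assert top == 0 and len(q) == image_size  # window length stays fixed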
@@ -103,18 +104,30 @@ def __init__(self, fn, image_size, seq_len):
         self.fn = fn
         self.image_size = image_size
         self.seq_len = seq_len
+        self.img_seq_len = image_size ** 2
+        self.text_len = seq_len - self.img_seq_len + 1
 
     def forward(self, x, cache = None, cache_key = None, **kwargs):
-        n0 = x.shape[1]
-        if exists(cache):
-            if cache_key in cache:
-                x = torch.cat([cache[cache_key], x], dim = -2)
-            cache[cache_key] = x
+        seq_len, image_size, text_len = self.seq_len, self.image_size, self.text_len
+
+        if exists(cache) and cache_key in cache:
+            offset = cache['offset']
+            assert offset >= text_len, "cached inference for text is not supported"
+            q = cache[cache_key]
+            assert isinstance(q, deque) and len(q) == image_size
+
+            x_top, x_left, *x_pass = x[:, -1].chunk(4, dim = -1)
+
+            q.append((x_top, x_left))
+            x_top = q.popleft()[0]   # chunk of the token one image row above
+            x_left = q[-2][1]        # chunk of the previous token in the row
+            if (offset - text_len) % image_size == 0:
+                x_left = torch.zeros_like(x_left)   # row start: no left neighbor
+
+            x = torch.cat((x_top, x_left, *x_pass), dim = -1)
+            return self.fn(x[:, None], cache = cache, **kwargs)
 
         n = x.shape[1]
-        seq_len, image_size = self.seq_len, self.image_size
-        img_seq_len = image_size ** 2
-        text_len = seq_len - img_seq_len + 1
         padding = seq_len - n + 1
 
         # get text and image tokens
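For a single cached decoding step, the wrapper only needs two neighbors of the current grid position, which is exactly what the deque rotation above computes. A sketch of the underlying index arithmetic, using a hypothetical `neighbors` helper that is not part of this diff:

    def neighbors(offset, text_len, image_size):
        i = offset - text_len                            # flat index into the image grid
        top = i - image_size                             # supplied by q.popleft()
        left = None if i % image_size == 0 else i - 1    # supplied by q[-2], zeroed at row starts
        return top, left

    # e.g. an 8x8 image grid following 128 text tokens:
    assert neighbors(137, 128, 8) == (1, 8)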
@@ -139,8 +152,22 @@ def forward(self, x, cache=None, cache_key=None, **kwargs):
         # merge text and image sequence back together
 
         x_img = rearrange(x_img, 'b h w d -> b (h w) d')
-        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
-        return self.fn(x[:, -n0:], cache = cache, **kwargs)
+        x_img = x_img[:, :-padding]
+        x = torch.cat((x_text, x_img), dim = 1)
+
+        if exists(cache):
+            dummy_top, dummy_left, *_ = x[:, -1].chunk(4, dim = -1)
+            dummy_top, dummy_left = torch.zeros_like(dummy_top), torch.zeros_like(dummy_left)
+
+            # prime the queue with the last image row, zero-padded if partial
+            q = deque()
+            x_img = x_img[:, -image_size:]
+            for _ in range(image_size - x_img.shape[1]):
+                q.append((dummy_top, dummy_left))
+            for i in range(x_img.shape[1]):
+                q.append(x_img[:, i].chunk(4, dim = -1)[:2])
+            cache[cache_key] = q
+
+        return self.fn(x, cache = cache, **kwargs)
 
 # main transformer class
 
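On the first, non-cached pass, the tail of the hunk above primes the queue with the chunk pairs of the last `image_size` positions, zero-padding when fewer image tokens exist yet. A rough standalone rendering of that priming step (the `prime_queue` name and shapes are illustrative):

    import torch
    from collections import deque

    def prime_queue(x_img, image_size):
        # x_img: (batch, n, dim) image features from the full forward pass
        q = deque()
        tail = x_img[:, -image_size:]
        dummy = torch.zeros_like(tail[:, 0].chunk(4, dim = -1)[0])
        for _ in range(image_size - tail.shape[1]):
            q.append((dummy, dummy))                       # zero-pad a partial first row
        for i in range(tail.shape[1]):
            q.append(tail[:, i].chunk(4, dim = -1)[:2])    # keep only (top, left) chunks
        return q

    q = prime_queue(torch.randn(1, 3, 16), image_size = 8)
    assert len(q) == 8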
@@ -277,6 +304,11 @@ def forward(self, x, **kwargs):
         return self.layers(x, rotary_pos_emb = self.pos_emb, **kwargs)
 
     def _get_static_mask(self, attn_type):
+        # For attn_type = "axial_{row,col}", the sparse implementation is most
+        # efficient for training, but full attention with a static mask is most
+        # efficient for inference, since caching is only implemented for the
+        # dense path.
+
         img_seq_len = self.image_fmap_size ** 2
         text_len = self.seq_len + 1 - img_seq_len