 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
-from dalle_pytorch.cache import FixCacheKey

 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from g_mlp_pytorch import gMLPBlock
@@ -36,6 +35,15 @@ def forward(self, x):
         maxes = x.amax(dim = self.dim, keepdim = True)
         return x / maxes

+class CachedAs(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache = None, **kwargs):
+        return self.fn(x, cache = cache, cache_key = self.cache_key, **kwargs)
+
 # https://arxiv.org/abs/2103.17239
 class LayerScale(nn.Module):
     def __init__(self, dim, depth, fn):
@@ -200,7 +208,7 @@ def __init__(
             ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
             shared_ff_layers[ff_id] = ff

-            attn = FixCacheKey(f'attn_{ind}', attn)
+            attn = CachedAs(f'attn_{ind}', attn)

             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))