@@ -65,19 +65,17 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
-        using_cache = exists(cache) and cache_key in cache
+        offset = cache.get('offset', 0) if exists(cache) else 0

         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

         if exists(rotary_pos_emb):
-            if using_cache:
-                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))

         q = q * self.scale

-        if using_cache:
+        if offset > 0:
             k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim = -2)
             v = torch.cat([v_top, v], dim = -2)
@@ -92,7 +90,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask

-        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+        if self.causal and offset == 0: # causality is naturally enforced for the cached inference
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
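
For context, here is a minimal sketch (not part of the diff) of how a caller might drive the cache convention introduced above: the cache dict is assumed to hold each layer's stored keys/values under its `cache_key`, plus a running `'offset'` counting positions already processed, which the attention uses to slice the rotary embeddings and to skip the explicit causal mask. The `model`, start token, and loop bound below are hypothetical stand-ins.

```python
import torch

# Hypothetical incremental decoding loop illustrating the assumed cache protocol:
# after the first step, only the newest token is fed, and cache['offset'] tells the
# attention layers how many positions are already cached. `model` is a stand-in.
cache = {}
tokens = torch.zeros(1, 1, dtype = torch.long)           # placeholder start token

for _ in range(16):                                       # arbitrary number of decode steps
    inp = tokens if len(cache) == 0 else tokens[:, -1:]   # full prompt first, then one token at a time
    logits = model(inp, cache = cache)                    # each Attention layer reads/writes cache[cache_key]
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    tokens = torch.cat([tokens, next_token], dim = -1)
    cache['offset'] = tokens.shape[1] - 1                 # positions already cached; rotary emb sliced from here
```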