Skip to content

Commit 6d3025d

Browse files
sxufacebook-github-bot
authored andcommitted
Create KV cache input tensor only if cache len > 0 for that layer (#15042)
Summary: The MHA branch has this logic already, add it to the other branch. Differential Revision: D84471388
1 parent d00279d commit 6d3025d

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

examples/models/llama/static_attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def __init__(
297297
dtype=dtype,
298298
)
299299
for layer_id in range(config.n_layers)
300+
if cache_lens[layer_id] > 0
300301
}
301302
self.v_caches = {
302303
StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
@@ -307,6 +308,7 @@ def __init__(
307308
dtype=dtype,
308309
)
309310
for layer_id in range(config.n_layers)
311+
if cache_lens[layer_id] > 0
310312
}
311313

312314
self.config = config

0 commit comments

Comments
 (0)