force is_causal for ndc
dan-garvey committed Oct 30, 2024
1 parent d3b0681 · commit 344cda2
Showing 1 changed file with 3 additions and 1 deletion.
sharktank/sharktank/layers/paged_llama_attention_block.py (3 additions, 1 deletion)
@@ -170,7 +170,9 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 attn_weights, values
             )  # (bs, heads, slen, head_dim)
         else:
-            is_causal = attention_mask is None and batch_seq_len == 1
+            # Use the builtin attention mask when not decomposed
+            is_causal = True
+            attention_mask = None
             attn_output = torch.nn.functional.scaled_dot_product_attention(
                 query=xq,  # [bs, ..., sl, dim]
                 key=keys,  # [bs, ..., sl, dim]
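For context, a minimal sketch (not part of this commit) of why forcing `is_causal=True` with `attention_mask=None` matches the explicit-mask path for causal attention in `scaled_dot_product_attention`. The tensor names and shapes (`q`, `k`, `v`, `bs`, `heads`, `sl`, `dim`) are illustrative assumptions, not values taken from sharktank:

```python
# Hypothetical standalone example; shape names follow the comments in the
# diff above: (bs, heads, sl, dim).
import torch

bs, heads, sl, dim = 2, 4, 8, 16
q = torch.randn(bs, heads, sl, dim)
k = torch.randn(bs, heads, sl, dim)
v = torch.randn(bs, heads, sl, dim)

# Builtin causal masking, as the patched code now requests.
out_builtin = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, is_causal=True
)

# Equivalent explicit lower-triangular boolean mask (True = attend).
causal_mask = torch.tril(torch.ones(sl, sl, dtype=torch.bool))
out_explicit = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=causal_mask, is_causal=False
)

assert torch.allclose(out_builtin, out_explicit, atol=1e-5)
```

Letting SDPA construct the causal mask internally also allows PyTorch to select fused attention kernels that a materialized `attn_mask` can rule out.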
