
Commit

[DOCS] fixed flash_attn causal argument in tutorial
ptillet committed Jul 11, 2023
1 parent bbc1ad1 commit 041f114
Showing 1 changed file with 1 addition and 1 deletion.
python/tutorials/06-fused-attention.py: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
 cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
 cu_seqlens[1:] = lengths.cumsum(0)
 qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
-fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
+fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
 if mode == 'bwd':
     o = fn()
     do = torch.randn_like(o)
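For context, the change forwards the benchmark's own causal argument to the flash-attn baseline instead of hard-coding causal=True, so non-causal benchmark configurations now time a matching reference. A minimal sketch of the corrected pattern follows; make_baseline and max_seqlen are illustrative names, not part of the tutorial, and the positional arguments simply mirror the call shown in the diff (qkv, cu_seqlens, dropout probability 0., maximum sequence length):

# Minimal sketch of the pattern restored by this fix: the timed closure
# forwards whatever `causal` value the benchmark passes in, rather than
# hard-coding causal=True. `make_baseline` is a hypothetical helper.
def make_baseline(flash_attn_func, qkv, cu_seqlens, max_seqlen, causal):
    return lambda: flash_attn_func(qkv, cu_seqlens, 0., max_seqlen, causal=causal)

Threading the flag through like this keeps the flash-attn baseline and the Triton kernel on the same workload for every causal/mode configuration in the benchmark grid.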
