From 041f1144e8309e4f5bbf3769f35eb86859ed0ab9 Mon Sep 17 00:00:00 2001
From: Phil Tillet
Date: Tue, 11 Jul 2023 09:28:20 -0700
Subject: [PATCH] [DOCS] fixed flash_attn causal argument in tutorial

---
 python/tutorials/06-fused-attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 2c7254de04f6..98804407b11b 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -409,7 +409,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
         cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
         cu_seqlens[1:] = lengths.cumsum(0)
         qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
-        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
+        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
         if mode == 'bwd':
             o = fn()
             do = torch.randn_like(o)