Skip to content

Commit

Permalink
Update thunder/tests/test_cudnn_executor.py
Browse files Browse the repository at this point in the history
Co-authored-by: Ivan Yashchuk <[email protected]>
  • Loading branch information
2 people authored and Borda committed Mar 21, 2024
1 parent 635ac44 commit 84fb458
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions thunder/tests/test_cudnn_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req

# Non-contiguous input tensor case
nq = make(N, n_head, E, L).permute(0, 1, 3, 2)
nk = make(N, n_head, E, L).permute(0, 1, 3, 2)
nv = make(N, n_head, E, L).permute(0, 1, 3, 2)
nk = make(N, n_head, E, S).permute(0, 1, 3, 2)
nv = make(N, n_head, E, S).permute(0, 1, 3, 2)
yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False)

# Test the scale factor which was added in torch 2.1
Expand Down

0 comments on commit 84fb458

Please sign in to comment.