Skip to content

Commit

Permalink
Fix test_cse_rematerialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored and Borda committed Mar 21, 2024
1 parent 180a1ff commit 6d0ca5e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,14 +353,14 @@ def test_cse_rematerialization(executor, device, _):

fw_trace = thunder.last_traces(compiled_func)[-1]
fusion_bsyms = tuple(filter(lambda a: a.sym.is_fusion, fw_trace.bound_symbols))
assert len(fusion_bsyms) == 13
# fusion groups 1 and 7 correspond with the apply_rotary_emb function
assert len(fusion_bsyms) == 11
# fusion groups 1 and 6 correspond with the apply_rotary_emb function
# Nvfuser with recomputation should use precomputed cos and sin values.
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[7].args)
assert len(fusion_bsyms[1].args) == len(fusion_bsyms[6].args)
assert fusion_bsyms[1].args[0].name == "freqs_cos"
assert fusion_bsyms[1].args[1].name == "freqs_sin"
assert fusion_bsyms[7].args[0].name == "freqs_cos"
assert fusion_bsyms[7].args[1].name == "freqs_sin"
assert fusion_bsyms[6].args[0].name == "freqs_cos"
assert fusion_bsyms[6].args[1].name == "freqs_sin"


# Tests that two separated nvFuser regions can be merged when they don't depend
Expand Down

0 comments on commit 6d0ca5e

Please sign in to comment.