bugfix: fix unittests not yield error in AOT mode (#657)
The unittests in AOT mode have been failing since
#629 because we used
`return` instead of `yield` in the warmup fixtures; this PR fixes the
issue.
yzh119 authored on Dec 13, 2024 · 1 parent 4c15777 · commit 6dfc9d8
Showing 10 changed files with 10 additions and 10 deletions.
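
All ten files get the same one-line fix: the early `return` in the module-scoped `warmup_jit` fixture becomes a `yield`. Because the fixture body contains a `yield` further down, Python treats the whole function as a generator; returning before the first yield means the fixture never produces a value, and pytest aborts the module with an error rather than running its tests. Below is a minimal, self-contained sketch of the corrected control flow; `prebuilt_ops_available` is a hypothetical stand-in for `flashinfer.jit.has_prebuilt_ops`, and the actual upstream diff only swaps the single keyword:

```python
import pytest


def prebuilt_ops_available() -> bool:
    # Hypothetical stand-in for flashinfer.jit.has_prebuilt_ops
    # (True when running in AOT mode with prebuilt kernels).
    return True


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    # The `yield` below makes this function a generator, so a bare early
    # `return` would end the generator before it ever yields, and pytest
    # would report an error (roughly "fixture ... did not yield a value")
    # instead of running the module's tests.
    if prebuilt_ops_available():
        yield   # AOT mode: prebuilt ops need no warmup; hand control to the tests
        return  # end the generator cleanly on teardown
    # JIT mode: compile/load the kernels once per module before any test runs
    # (the real suite calls flashinfer.jit.parallel_load_modules(...) here).
    yield


def test_smoke():
    # autouse=True means every test in the module runs after the warmup.
    assert True
```

Yielding exactly once on every path is what the pytest fixture protocol requires; `scope="module"` keeps the warmup to at most one run per test module.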
tests/test_alibi.py (1 addition, 1 deletion)
@@ -26,7 +26,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_batch_decode_kernels.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_batch_prefill_kernels.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_prefill_attention_func_args(
tests/test_block_sparse.py (1 addition, 1 deletion)
@@ -26,7 +26,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_logits_cap.py (1 addition, 1 deletion)
@@ -26,7 +26,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_non_contiguous_decode.py (1 addition, 1 deletion)
@@ -8,7 +8,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_non_contiguous_prefill.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_prefill_attention_func_args(
tests/test_shared_prefix_kernels.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_sliding_window.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
tests/test_tensor_cores_decode.py (1 addition, 1 deletion)
@@ -24,7 +24,7 @@
 @pytest.fixture(autouse=True, scope="module")
 def warmup_jit():
     if flashinfer.jit.has_prebuilt_ops:
-        return
+        yield
     try:
         flashinfer.jit.parallel_load_modules(
             jit_decode_attention_func_args(
