forked from ivy-llc/ivy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(backend handler): throwing a warning when a sub-backend implemen…
…tation is available but is not being used. (ivy-llc#23489)
- Loading branch information
Showing
10 changed files
with
186 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
45 changes: 45 additions & 0 deletions
45
ivy_tests/test_ivy/test_misc/test_backend_utils/test_sub_backends_available.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pytest | ||
|
||
from importlib.util import find_spec | ||
import ivy | ||
|
||
|
||
def test_no_warning_when_no_sub_backend_implementation_available(): | ||
ivy.set_backend("numpy") | ||
q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]]) | ||
k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]]) | ||
v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]]) | ||
with pytest.warns(None) as record: | ||
ivy.scaled_dot_product_attention( | ||
q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True | ||
) | ||
assert len(record) == 0 | ||
|
||
|
||
@pytest.mark.skipif(find_spec("xformers") is None, reason="xformers is not installed") | ||
def test_sub_backend_implementation_available(): | ||
ivy.set_backend("torch") | ||
sub_backends = ivy.available_sub_backend_implementations( | ||
"scaled_dot_product_attention" | ||
) | ||
assert "xformers" in sub_backends | ||
|
||
|
||
def test_sub_backend_implementation_not_available(): | ||
ivy.set_backend("numpy") | ||
sub_backends = ivy.available_sub_backend_implementations( | ||
"scaled_dot_product_attention" | ||
) | ||
assert not sub_backends | ||
|
||
|
||
@pytest.mark.skipif(find_spec("xformers") is None, reason="xformers is not installed") | ||
def test_throw_warning_when_sub_backend_implementation_available_but_not_used(): | ||
ivy.set_backend("torch") | ||
q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]]) | ||
k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]]) | ||
v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]]) | ||
with pytest.warns(UserWarning): | ||
ivy.scaled_dot_product_attention( | ||
q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True | ||
) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters