Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: throwing a warning when a sub-backend implementation is available but is not being used. #23489

Merged
merged 22 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a0b1010
fix: replaced deprecated logging.warn call with logging.warning in su…
Madjid-CH Sep 4, 2023
7c18158
feat(Ivy Framework):started working on throw_warning_when_sub_backend…
Sep 10, 2023
c750974
fix(tests): fixed failing test_sub_backend_implementation_available
Sep 10, 2023
a486df9
refactor(sub_backend_handler)
Sep 11, 2023
56ec27c
feat(func_wrapper) added efficient_implementation_available decorator.
Sep 11, 2023
df0e0fe
feat(backend handler): made the set_backend wrap the function with _w…
Sep 12, 2023
6c5bebe
fix(func_wrapper): made _warn_efficient_implementation_available rais…
Sep 12, 2023
ae8c14b
fix(backend handler): test_throw_warning_when_sub_backend_implementat…
Sep 12, 2023
401e895
refactor(tests): renamed test_sub_back.py to test_sub_backends_availa…
Sep 12, 2023
5f414e0
refactored sub_backend_handler.py and added sub backend verification …
Sep 13, 2023
df09039
refactored sub_backend_handler.py reintroduce _check_callable in sub_…
Sep 13, 2023
29ac2aa
inlined variable in available_sub_backend_implementations function
Sep 13, 2023
a84a466
removed _check_callable helper since it's not needed anymore after pa…
Sep 14, 2023
f3ba4a0
renamed _set_backend_as_ivy to _set_module_backend
Sep 18, 2023
153eb85
Merge remote-tracking branch 'upstream/main' into sub-backend-impleme…
Sep 18, 2023
dca8bc4
simplified finding available sub backends implementations by removing…
Sep 19, 2023
e47e9ba
moved backend tests to a new folder
Sep 19, 2023
c1954f0
updated tests to skip xformers is not installed
Sep 19, 2023
aa729b0
removed '_for' from 'available_sub_backend_implementations_for' funct…
Sep 20, 2023
724f5e9
Merging
Sep 20, 2023
265a472
restructured sub_backend_handler.py
Sep 20, 2023
98e44a6
fixed failing test_set_backend with paddle backend
Sep 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,7 @@ def dynamic_backend_as(value):
unset_sub_backend,
clear_sub_backends,
available_sub_backends,
available_sub_backend_implementations,
)


Expand Down
32 changes: 32 additions & 0 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,38 @@ def func(x):
return _handle_backend_invalid


def _handle_efficient_implementation_available(fn: Callable) -> Callable:
@functools.wraps(fn)
def _wrapper(*args, **kwargs):
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
"""
Throws a warning whenever the function is called, indicating that an efficient
implementation is available. This should be used for functions which have an
efficient implementation in a sub backend.

Parameters
----------
args
The arguments to be passed to the function.

kwargs
The keyword arguments to be passed to the function.

Returns
-------
The return of the function.
"""
ivy.warn(
f"An efficient implementation of {fn.__name__} is available "
"in these sub backends: "
f"{ivy.available_sub_backend_implementations(fn.__name__)}.\n"
"use ivy.set_sub_backend('<sub_backend_name>') to use it."
)
return fn(*args, **kwargs)

_wrapper.efficient_implementation_available = True
return _wrapper


attribute_dict = {
"unsupported_dtypes",
"supported_dtypes",
Expand Down
21 changes: 15 additions & 6 deletions ivy/utils/backend/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ivy.utils import _importlib, verbosity

# local
from ivy.func_wrapper import _wrap_function
from ivy.func_wrapper import _wrap_function, _handle_efficient_implementation_available
from ivy.utils.backend.sub_backend_handler import _clear_current_sub_backends
from ivy.utils.exceptions import _handle_inplace_mode

Expand Down Expand Up @@ -243,7 +243,7 @@ def current_backend(*args, **kwargs):
return importlib.import_module(_backend_dict[implicit_backend])


def _set_backend_as_ivy(
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
def _modify_module_symbols(
original_dict, target, backend, invalid_dtypes=None, backend_str=None
):
invalid_dtypes = (
Expand All @@ -252,11 +252,13 @@ def _set_backend_as_ivy(
backend_str = backend.current_backend_str() if backend_str is None else backend_str
for k, v in original_dict.items():
compositional = k not in backend.__dict__
if k not in backend.__dict__:
if compositional:
if k in invalid_dtypes and k in target.__dict__:
del target.__dict__[k]
continue
v = _verify_efficient_implementations(v) if v else v
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
backend.__dict__[k] = v
v = _verify_efficient_implementations(v) if v else v
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
target.__dict__[k] = _wrap_function(
key=k, to_wrap=backend.__dict__[k], original=v, compositional=compositional
)
Expand All @@ -265,7 +267,7 @@ def _set_backend_as_ivy(
and "ivy.functional." in v.__name__
and os.path.join("{}", "__init__.py").format(backend_str) not in v.__file__
):
_set_backend_as_ivy(
_modify_module_symbols(
v.__dict__,
target.__dict__[k],
backend.__dict__[k],
Expand All @@ -274,6 +276,13 @@ def _set_backend_as_ivy(
)


def _verify_efficient_implementations(v):
if callable(v):
if ivy.available_sub_backend_implementations(v):
return _handle_efficient_implementation_available(v)
return v


def _handle_backend_specific_vars(target, backend):
if backend.current_backend_str() == "numpy":
target.set_default_device("cpu")
Expand Down Expand Up @@ -453,7 +462,7 @@ def set_backend(backend: str, dynamic: bool = False):
ivy.set_global_attr("RNG", ivy.functional.backends.jax.random.RNG)
backend_stack.append(backend)
set_backend_to_specific_version(backend)
_set_backend_as_ivy(ivy_original_dict, ivy, backend)
_modify_module_symbols(ivy_original_dict, ivy, backend)
# following snippet is required to update the ivy.functional namespace with
# backend-specific functions
for key, _ in ivy.__dict__.items():
Expand Down Expand Up @@ -640,7 +649,7 @@ def with_backend(backend: str, cached: bool = True):
set_backend_to_specific_version(backend_module)
# We know for sure that the backend stack is empty
# no need to do backend unsetting
ivy_pack.utils.backend.handler._set_backend_as_ivy(
ivy_pack.utils.backend.handler._modify_module_symbols(
ivy_pack.__dict__.copy(), ivy_pack, backend_module
)
# TODO use a refactored code from ivy.set_backend
Expand Down
151 changes: 117 additions & 34 deletions ivy/utils/backend/sub_backend_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,70 @@
from types import ModuleType, FunctionType
import logging
import importlib
from typing import Callable

import ivy
from ivy.func_wrapper import _wrap_function
from ivy.utils.exceptions import IvyException


_backends_subpackage_path = "ivy.functional.backends"
_sub_backend_dict = dict()
_backend_to_sub_backends_dict = dict()

# dynamic sub_backend detection
for backend in os.listdir(
os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0], # type: ignore
_backends_subpackage_path.replace(".", os.path.sep),
)
):
if not backend[0].isalpha():
continue

sub_backends_dir = os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0],
_backends_subpackage_path.replace(".", os.path.sep),
backend,
"sub_backends",
)
for sub_backend in os.listdir(sub_backends_dir):
if not sub_backend[0].isalpha():
_sub_backend_dict: dict[str, str] = dict()
_backend_to_sub_backends_dict: dict[str, list] = dict()
_available_sub_backends_implementations_dict: dict[str, dict[str, list]] = dict()


def _detect_sub_backends_dynamically():
for backend in os.listdir(
os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0], # type: ignore
_backends_subpackage_path.replace(".", os.path.sep),
)
):
if not backend[0].isalpha():
continue
_sub_backend_dict[sub_backend] = (
f"{_backends_subpackage_path}.{backend}.sub_backends.{sub_backend}"

sub_backends_dir = os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0],
_backends_subpackage_path.replace(".", os.path.sep),
backend,
"sub_backends",
)
try:
_backend_to_sub_backends_dict[backend].append(sub_backend)
except KeyError:
_backend_to_sub_backends_dict[backend] = [sub_backend]
for sub_backend in os.listdir(sub_backends_dir):
if not sub_backend[0].isalpha():
continue
_sub_backend_dict[sub_backend] = (
f"{_backends_subpackage_path}.{backend}.sub_backends.{sub_backend}"
)
try:
_backend_to_sub_backends_dict[backend].append(sub_backend)
except KeyError:
_backend_to_sub_backends_dict[backend] = [sub_backend]


_detect_sub_backends_dynamically()

_all_sub_backends = []

for v in _backend_to_sub_backends_dict.values():
_all_sub_backends.extend(v)
def _get_all_sub_backends():
result = []
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
for v in _backend_to_sub_backends_dict.values():
result.extend(v)
return result


_all_sub_backends = _get_all_sub_backends()


original_backend_dict = None


def set_sub_backend(sub_backend_str: str):
if ivy.backend == "":
logging.warn("You must set a backend first")
logging.warning("You must set a backend first")
return

if ivy.current_backend_str() not in _backend_to_sub_backends_dict.keys():
logging.warn(
logging.warning(
f"backend {ivy.current_backend_str()} does not have any"
" supported sub_backends"
)
Expand All @@ -68,7 +78,7 @@ def set_sub_backend(sub_backend_str: str):
)

if sub_backend_str not in _backend_to_sub_backends_dict[ivy.current_backend_str()]:
logging.warn(
logging.warning(
f"{ivy.current_backend_str()} does not support"
f" {sub_backend_str} as a sub_backend"
)
Expand All @@ -87,7 +97,7 @@ def set_sub_backend(sub_backend_str: str):
ivy.current_backend().sub_backends._current_sub_backends.append(sub_backend_str)


# this is very similiar to _set_backend_as_ivy in handler.py, with a minor change
# this is very similar to _set_backend_as_ivy in handler.py, with a minor change
def _set_sub_backend_as_ivy(
original: dict, target: ModuleType, sub_backend: ModuleType
):
Expand Down Expand Up @@ -200,3 +210,76 @@ def find_available_sub_backends(sub_backends_loc):
available_sub_backends.append(sub_backend)

return available_sub_backends


def _find_available_sub_backend_implementations(sub_backends):
result = dict()
for sub in sub_backends:
sub_backend = ivy.utils.dynamic_import.import_module(_sub_backend_dict[sub])
for k, v in sub_backend.__dict__.items():
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(v, Callable) and not k.startswith("__"):
result[k] = result.get(k, []) + [sub]

return result


def available_sub_backend_implementations(obj: str) -> list:
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
"""
Return whether a sub-backend implementation is available for `obj`.

Parameters
----------
obj : str
the object for which to check if a sub-backend implementation is available.

Returns
-------
ret : list
a list of sub-backend implementations available for `obj`.

Examples
--------
>>> import ivy
>>> ivy.set_backend('torch')
>>> ivy.available_sub_backend_implementations("scaled_dot_product_attention")
['xformers']
>>> ivy.set_backend('numpy')
>>> ivy.available_sub_backend_implementations("scaled_dot_product_attention")
[]
"""
obj = _check_callable(obj)
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
sub_backends = ivy.current_backend().available_sub_backends()
result = []
if not sub_backends:
return result
if not _sub_backends_implementations_already_verified():
_verify_sub_backends_implementations(sub_backends)
return _available_implementations_for(obj)


def _check_callable(obj):
Madjid-CH marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(obj, str):
obj = getattr(ivy, obj)
if not callable(obj):
raise TypeError(
"The argument `obj` must be a callable or a string representing a callable"
)
return obj


def _sub_backends_implementations_already_verified():
return (
ivy.current_backend_str() in _available_sub_backends_implementations_dict.keys()
)


def _verify_sub_backends_implementations(sub_backends):
_available_sub_backends_implementations_dict[ivy.current_backend_str()] = (
_find_available_sub_backend_implementations(sub_backends)
)


def _available_implementations_for(obj):
return _available_sub_backends_implementations_dict[ivy.current_backend_str()].get(
obj.__name__, []
)
6 changes: 6 additions & 0 deletions ivy_tests/test_ivy/test_misc/test_func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,9 @@ def test_views(array_to_update, backend_fw):
assert np.allclose(d, d_copy + 1)
assert np.allclose(e[0], e_copy + 1)
ivy.previous_backend()


def test_warn_efficient_implementations():
to_test = ivy.func_wrapper._handle_efficient_implementation_available(ivy.array)
with pytest.warns(UserWarning):
to_test(1)
42 changes: 42 additions & 0 deletions ivy_tests/test_ivy/test_misc/test_sub_backends_available.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

import ivy


def test_no_warning_when_no_sub_backend_implementation_available():
ivy.set_backend("numpy")
q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]])
k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]])
v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]])
with pytest.warns(None) as record:
ivy.scaled_dot_product_attention(
q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True
)
assert len(record) == 0


def test_sub_backend_implementation_available():
ivy.set_backend("torch")
sub_backends = ivy.available_sub_backend_implementations(
"scaled_dot_product_attention"
)
assert sub_backends == ["xformers"]


def test_sub_backend_implementation_not_available():
ivy.set_backend("numpy")
sub_backends = ivy.available_sub_backend_implementations(
"scaled_dot_product_attention"
)
assert not sub_backends


def test_throw_warning_when_sub_backend_implementation_available_but_not_used():
ivy.set_backend("torch")
q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]])
k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]])
v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]])
with pytest.warns(UserWarning):
ivy.scaled_dot_product_attention(
q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True
)
Loading