Skip to content

Commit

Permalink
feat(backend handler): throwing a warning when a sub-backend implemen…
Browse files Browse the repository at this point in the history
…tation is available but is not being used. (ivy-llc#23489)
  • Loading branch information
Madjid-CH authored and iababio committed Sep 27, 2023
1 parent 3f3fb2f commit 5816ab8
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 46 deletions.
1 change: 1 addition & 0 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,7 @@ def dynamic_backend_as(value):
unset_sub_backend,
clear_sub_backends,
available_sub_backends,
available_sub_backend_implementations,
)


Expand Down
32 changes: 32 additions & 0 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,38 @@ def func(x):
return _handle_backend_invalid


def _handle_efficient_implementation_available(fn: Callable) -> Callable:
@functools.wraps(fn)
def _wrapper(*args, **kwargs):
"""
Throws a warning whenever the function is called, indicating that an efficient
implementation is available. This should be used for functions which have an
efficient implementation in a sub backend.
Parameters
----------
args
The arguments to be passed to the function.
kwargs
The keyword arguments to be passed to the function.
Returns
-------
The return of the function.
"""
ivy.warn(
f"An efficient implementation of {fn.__name__} is available "
"in these sub backends: "
f"{ivy.available_sub_backend_implementations(fn.__name__)}.\n"
"use ivy.set_sub_backend('<sub_backend_name>') to use it."
)
return fn(*args, **kwargs)

_wrapper.efficient_implementation_available = True
return _wrapper


attribute_dict = {
"unsupported_dtypes",
"supported_dtypes",
Expand Down
5 changes: 5 additions & 0 deletions ivy/functional/backends/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,8 @@ def closest_valid_dtype(type=None, /, as_native=False):
from .experimental import *
from . import control_flow_ops
from .control_flow_ops import *

# sub-backends

from . import sub_backends
from .sub_backends import *
20 changes: 14 additions & 6 deletions ivy/utils/backend/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ivy.utils import _importlib, verbosity

# local
from ivy.func_wrapper import _wrap_function
from ivy.func_wrapper import _wrap_function, _handle_efficient_implementation_available
from ivy.utils.backend.sub_backend_handler import (
_clear_current_sub_backends,
fn_name_from_version_specific_fn_name,
Expand Down Expand Up @@ -199,16 +199,17 @@ def current_backend(*args, **kwargs):
return importlib.import_module(_backend_dict[implicit_backend])


def _set_backend_as_ivy(
def _set_module_backend(
original_dict, target, backend, invalid_dtypes=None, backend_str=None
):
invalid_dtypes = (
backend.invalid_dtypes if invalid_dtypes is None else invalid_dtypes
)
backend_str = backend.current_backend_str() if backend_str is None else backend_str
for k, v in original_dict.items():
v = _wrap_if_got_efficient_implementations(v) if v else v
compositional = k not in backend.__dict__
if k not in backend.__dict__:
if compositional:
if k in invalid_dtypes and k in target.__dict__:
del target.__dict__[k]
continue
Expand All @@ -221,7 +222,7 @@ def _set_backend_as_ivy(
and "ivy.functional." in v.__name__
and os.path.join("{}", "__init__.py").format(backend_str) not in v.__file__
):
_set_backend_as_ivy(
_set_module_backend(
v.__dict__,
target.__dict__[k],
backend.__dict__[k],
Expand All @@ -230,6 +231,13 @@ def _set_backend_as_ivy(
)


def _wrap_if_got_efficient_implementations(v):
if callable(v):
if ivy.available_sub_backend_implementations(v.__name__):
return _handle_efficient_implementation_available(v)
return v


def _handle_backend_specific_vars(target, backend):
if backend.current_backend_str() == "numpy":
target.set_default_device("cpu")
Expand Down Expand Up @@ -409,7 +417,7 @@ def set_backend(backend: str, dynamic: bool = False):
ivy.set_global_attr("RNG", ivy.functional.backends.jax.random.RNG)
backend_stack.append(backend)
set_backend_to_specific_version(backend)
_set_backend_as_ivy(ivy_original_dict, ivy, backend)
_set_module_backend(ivy_original_dict, ivy, backend)
# following snippet is required to update the ivy.functional namespace with
# backend-specific functions
for key, _ in ivy.__dict__.items():
Expand Down Expand Up @@ -594,7 +602,7 @@ def with_backend(backend: str, cached: bool = True):
set_backend_to_specific_version(backend_module)
# We know for sure that the backend stack is empty
# no need to do backend unsetting
ivy_pack.utils.backend.handler._set_backend_as_ivy(
ivy_pack.utils.backend.handler._set_module_backend(
ivy_pack.__dict__.copy(), ivy_pack, backend_module
)
# TODO use a refactored code from ivy.set_backend
Expand Down
123 changes: 83 additions & 40 deletions ivy/utils/backend/sub_backend_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,52 @@


_backends_subpackage_path = "ivy.functional.backends"
_sub_backend_dict = {}
_backend_to_sub_backends_dict = {}
_sub_backend_dict: dict[str, str] = {}
_backend_to_sub_backends_dict: dict[str, list] = {}


def _detect_sub_backends_dynamically():
for backend in os.listdir(
os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0], # type: ignore
_backends_subpackage_path.replace(".", os.path.sep),
)
):
if not backend[0].isalpha():
continue

sub_backends_dir = os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0],
_backends_subpackage_path.replace(".", os.path.sep),
backend,
"sub_backends",
)
for sub_backend in os.listdir(sub_backends_dir):
if not sub_backend[0].isalpha():
continue
_sub_backend_dict[sub_backend] = (
f"{_backends_subpackage_path}.{backend}.sub_backends.{sub_backend}"
)
try:
_backend_to_sub_backends_dict[backend].append(sub_backend)
except KeyError:
_backend_to_sub_backends_dict[backend] = [sub_backend]


_detect_sub_backends_dynamically()


def _get_all_sub_backends():
sub_backends = []
for v in _backend_to_sub_backends_dict.values():
sub_backends.extend(v)
return sub_backends


_all_sub_backends = _get_all_sub_backends()


original_backend_dict = None


# version specific sub-backend setting
Expand Down Expand Up @@ -122,43 +166,6 @@ def fn_name_from_version_specific_fn_name_sub_backend(
return name[: v_occurences[0]]


# dynamic sub_backend detection
for backend in os.listdir(
os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0], # type: ignore
_backends_subpackage_path.replace(".", os.path.sep),
)
):
if not backend[0].isalpha():
continue

sub_backends_dir = os.path.join(
ivy.__path__[0].rpartition(os.path.sep)[0],
_backends_subpackage_path.replace(".", os.path.sep),
backend,
"sub_backends",
)
for sub_backend in os.listdir(sub_backends_dir):
if not sub_backend[0].isalpha():
continue
_sub_backend_dict[sub_backend] = (
f"{_backends_subpackage_path}.{backend}.sub_backends.{sub_backend}"
)
try:
_backend_to_sub_backends_dict[backend].append(sub_backend)
except KeyError:
_backend_to_sub_backends_dict[backend] = [sub_backend]


_all_sub_backends = []

for v in _backend_to_sub_backends_dict.values():
_all_sub_backends.extend(v)


original_backend_dict = None


def set_sub_backend(sub_backend_str: str):
if ivy.backend == "":
logging.warning("You must set a backend first")
Expand Down Expand Up @@ -198,7 +205,7 @@ def set_sub_backend(sub_backend_str: str):
ivy.current_backend().sub_backends._current_sub_backends.append(sub_backend_str)


# this is very similiar to _set_backend_as_ivy in handler.py, with a minor change
# this is very similar to _set_backend_as_ivy in handler.py, with a minor change
def _set_sub_backend_as_ivy(
original: dict, target: ModuleType, sub_backend: ModuleType
):
Expand Down Expand Up @@ -311,3 +318,39 @@ def find_available_sub_backends(sub_backends_loc):
available_sub_backends.append(sub_backend)

return available_sub_backends


def available_sub_backend_implementations(fn_name: str) -> list:
"""
Return whether a sub-backend implementation is available for `fn_name`.
Parameters
----------
fn_name : str
the object for which to check if a sub-backend implementation is available.
Returns
-------
ret : list
a list of sub-backend implementations available for `fn_name`.
Examples
--------
>>> import ivy
>>> ivy.set_backend('torch')
>>> ivy.available_sub_backend_implementations("scaled_dot_product_attention")
['xformers']
>>> ivy.set_backend('numpy')
>>> ivy.available_sub_backend_implementations("scaled_dot_product_attention")
[]
"""
sub_backends = ivy.current_backend().available_sub_backends()
implementations = []
for sub in sub_backends:
try:
sub_backend = ivy.utils.dynamic_import.import_module(_sub_backend_dict[sub])
except ModuleNotFoundError:
continue
if fn_name in sub_backend.__dict__:
implementations.append(sub)
return implementations
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from importlib.util import find_spec
import ivy


def test_no_warning_when_no_sub_backend_implementation_available():
ivy.set_backend("numpy")
q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]])
k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]])
v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]])
with pytest.warns(None) as record:
ivy.scaled_dot_product_attention(
q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True
)
assert len(record) == 0


@pytest.mark.skipif(find_spec("xformers") is None, reason="xformers is not installed")
def test_sub_backend_implementation_available():
ivy.set_backend("torch")
sub_backends = ivy.available_sub_backend_implementations(
"scaled_dot_product_attention"
)
assert "xformers" in sub_backends


def test_sub_backend_implementation_not_available():
ivy.set_backend("numpy")
sub_backends = ivy.available_sub_backend_implementations(
"scaled_dot_product_attention"
)
assert not sub_backends


@pytest.mark.skipif(find_spec("xformers") is None, reason="xformers is not installed")
def test_throw_warning_when_sub_backend_implementation_available_but_not_used():
ivy.set_backend("torch")
q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]])
k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]])
v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]])
with pytest.warns(UserWarning):
ivy.scaled_dot_product_attention(
q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True
)
6 changes: 6 additions & 0 deletions ivy_tests/test_ivy/test_misc/test_func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,9 @@ def test_views(array_to_update, backend_fw):
assert np.allclose(d, d_copy + 1)
assert np.allclose(e[0], e_copy + 1)
ivy.previous_backend()


def test_warn_efficient_implementations():
to_test = ivy.func_wrapper._handle_efficient_implementation_available(ivy.array)
with pytest.warns(UserWarning):
to_test(1)

0 comments on commit 5816ab8

Please sign in to comment.