diff --git a/ivy/__init__.py b/ivy/__init__.py index d9afb1632c9ba..a56613a8543fc 100644 --- a/ivy/__init__.py +++ b/ivy/__init__.py @@ -1223,6 +1223,7 @@ def dynamic_backend_as(value): unset_sub_backend, clear_sub_backends, available_sub_backends, + available_sub_backend_implementations, ) diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py index f3d6a066dd813..e5e9f3f555056 100644 --- a/ivy/func_wrapper.py +++ b/ivy/func_wrapper.py @@ -1594,6 +1594,38 @@ def func(x): return _handle_backend_invalid +def _handle_efficient_implementation_available(fn: Callable) -> Callable: + @functools.wraps(fn) + def _wrapper(*args, **kwargs): + """ + Throws a warning whenever the function is called, indicating that an efficient + implementation is available. This should be used for functions which have an + efficient implementation in a sub backend. + + Parameters + ---------- + args + The arguments to be passed to the function. + + kwargs + The keyword arguments to be passed to the function. + + Returns + ------- + The return of the function. + """ + ivy.warn( + f"An efficient implementation of {fn.__name__} is available " + "in these sub backends: " + f"{ivy.available_sub_backend_implementations(fn.__name__)}.\n" + "use ivy.set_sub_backend('') to use it." + ) + return fn(*args, **kwargs) + + _wrapper.efficient_implementation_available = True + return _wrapper + + attribute_dict = { "unsupported_dtypes", "supported_dtypes", diff --git a/ivy/functional/backends/paddle/__init__.py b/ivy/functional/backends/paddle/__init__.py index ae6274fb0b9ac..bd0f7259b687c 100644 --- a/ivy/functional/backends/paddle/__init__.py +++ b/ivy/functional/backends/paddle/__init__.py @@ -297,3 +297,8 @@ def closest_valid_dtype(type=None, /, as_native=False): from .experimental import * from . import control_flow_ops from .control_flow_ops import * + +# sub-backends + +from . import sub_backends +from .sub_backends import * diff --git a/ivy/utils/backend/handler.py b/ivy/utils/backend/handler.py index 937c48f3316fe..3648455dabdd0 100644 --- a/ivy/utils/backend/handler.py +++ b/ivy/utils/backend/handler.py @@ -10,7 +10,7 @@ from ivy.utils import _importlib, verbosity # local -from ivy.func_wrapper import _wrap_function +from ivy.func_wrapper import _wrap_function, _handle_efficient_implementation_available from ivy.utils.backend.sub_backend_handler import ( _clear_current_sub_backends, fn_name_from_version_specific_fn_name, @@ -199,7 +199,7 @@ def current_backend(*args, **kwargs): return importlib.import_module(_backend_dict[implicit_backend]) -def _set_backend_as_ivy( +def _set_module_backend( original_dict, target, backend, invalid_dtypes=None, backend_str=None ): invalid_dtypes = ( @@ -207,8 +207,9 @@ def _set_backend_as_ivy( ) backend_str = backend.current_backend_str() if backend_str is None else backend_str for k, v in original_dict.items(): + v = _wrap_if_got_efficient_implementations(v) if v else v compositional = k not in backend.__dict__ - if k not in backend.__dict__: + if compositional: if k in invalid_dtypes and k in target.__dict__: del target.__dict__[k] continue @@ -221,7 +222,7 @@ def _set_backend_as_ivy( and "ivy.functional." in v.__name__ and os.path.join("{}", "__init__.py").format(backend_str) not in v.__file__ ): - _set_backend_as_ivy( + _set_module_backend( v.__dict__, target.__dict__[k], backend.__dict__[k], @@ -230,6 +231,13 @@ def _set_backend_as_ivy( ) +def _wrap_if_got_efficient_implementations(v): + if callable(v): + if ivy.available_sub_backend_implementations(v.__name__): + return _handle_efficient_implementation_available(v) + return v + + def _handle_backend_specific_vars(target, backend): if backend.current_backend_str() == "numpy": target.set_default_device("cpu") @@ -409,7 +417,7 @@ def set_backend(backend: str, dynamic: bool = False): ivy.set_global_attr("RNG", ivy.functional.backends.jax.random.RNG) backend_stack.append(backend) set_backend_to_specific_version(backend) - _set_backend_as_ivy(ivy_original_dict, ivy, backend) + _set_module_backend(ivy_original_dict, ivy, backend) # following snippet is required to update the ivy.functional namespace with # backend-specific functions for key, _ in ivy.__dict__.items(): @@ -594,7 +602,7 @@ def with_backend(backend: str, cached: bool = True): set_backend_to_specific_version(backend_module) # We know for sure that the backend stack is empty # no need to do backend unsetting - ivy_pack.utils.backend.handler._set_backend_as_ivy( + ivy_pack.utils.backend.handler._set_module_backend( ivy_pack.__dict__.copy(), ivy_pack, backend_module ) # TODO use a refactored code from ivy.set_backend diff --git a/ivy/utils/backend/sub_backend_handler.py b/ivy/utils/backend/sub_backend_handler.py index da8b7d622ce23..3bfbb75ec36c5 100644 --- a/ivy/utils/backend/sub_backend_handler.py +++ b/ivy/utils/backend/sub_backend_handler.py @@ -10,8 +10,52 @@ _backends_subpackage_path = "ivy.functional.backends" -_sub_backend_dict = {} -_backend_to_sub_backends_dict = {} +_sub_backend_dict: dict[str, str] = {} +_backend_to_sub_backends_dict: dict[str, list] = {} + + +def _detect_sub_backends_dynamically(): + for backend in os.listdir( + os.path.join( + ivy.__path__[0].rpartition(os.path.sep)[0], # type: ignore + _backends_subpackage_path.replace(".", os.path.sep), + ) + ): + if not backend[0].isalpha(): + continue + + sub_backends_dir = os.path.join( + ivy.__path__[0].rpartition(os.path.sep)[0], + _backends_subpackage_path.replace(".", os.path.sep), + backend, + "sub_backends", + ) + for sub_backend in os.listdir(sub_backends_dir): + if not sub_backend[0].isalpha(): + continue + _sub_backend_dict[sub_backend] = ( + f"{_backends_subpackage_path}.{backend}.sub_backends.{sub_backend}" + ) + try: + _backend_to_sub_backends_dict[backend].append(sub_backend) + except KeyError: + _backend_to_sub_backends_dict[backend] = [sub_backend] + + +_detect_sub_backends_dynamically() + + +def _get_all_sub_backends(): + sub_backends = [] + for v in _backend_to_sub_backends_dict.values(): + sub_backends.extend(v) + return sub_backends + + +_all_sub_backends = _get_all_sub_backends() + + +original_backend_dict = None # version specific sub-backend setting @@ -122,43 +166,6 @@ def fn_name_from_version_specific_fn_name_sub_backend( return name[: v_occurences[0]] -# dynamic sub_backend detection -for backend in os.listdir( - os.path.join( - ivy.__path__[0].rpartition(os.path.sep)[0], # type: ignore - _backends_subpackage_path.replace(".", os.path.sep), - ) -): - if not backend[0].isalpha(): - continue - - sub_backends_dir = os.path.join( - ivy.__path__[0].rpartition(os.path.sep)[0], - _backends_subpackage_path.replace(".", os.path.sep), - backend, - "sub_backends", - ) - for sub_backend in os.listdir(sub_backends_dir): - if not sub_backend[0].isalpha(): - continue - _sub_backend_dict[sub_backend] = ( - f"{_backends_subpackage_path}.{backend}.sub_backends.{sub_backend}" - ) - try: - _backend_to_sub_backends_dict[backend].append(sub_backend) - except KeyError: - _backend_to_sub_backends_dict[backend] = [sub_backend] - - -_all_sub_backends = [] - -for v in _backend_to_sub_backends_dict.values(): - _all_sub_backends.extend(v) - - -original_backend_dict = None - - def set_sub_backend(sub_backend_str: str): if ivy.backend == "": logging.warning("You must set a backend first") @@ -198,7 +205,7 @@ def set_sub_backend(sub_backend_str: str): ivy.current_backend().sub_backends._current_sub_backends.append(sub_backend_str) -# this is very similiar to _set_backend_as_ivy in handler.py, with a minor change +# this is very similar to _set_backend_as_ivy in handler.py, with a minor change def _set_sub_backend_as_ivy( original: dict, target: ModuleType, sub_backend: ModuleType ): @@ -311,3 +318,39 @@ def find_available_sub_backends(sub_backends_loc): available_sub_backends.append(sub_backend) return available_sub_backends + + +def available_sub_backend_implementations(fn_name: str) -> list: + """ + Return whether a sub-backend implementation is available for `fn_name`. + + Parameters + ---------- + fn_name : str + the object for which to check if a sub-backend implementation is available. + + Returns + ------- + ret : list + a list of sub-backend implementations available for `fn_name`. + + Examples + -------- + >>> import ivy + >>> ivy.set_backend('torch') + >>> ivy.available_sub_backend_implementations("scaled_dot_product_attention") + ['xformers'] + >>> ivy.set_backend('numpy') + >>> ivy.available_sub_backend_implementations("scaled_dot_product_attention") + [] + """ + sub_backends = ivy.current_backend().available_sub_backends() + implementations = [] + for sub in sub_backends: + try: + sub_backend = ivy.utils.dynamic_import.import_module(_sub_backend_dict[sub]) + except ModuleNotFoundError: + continue + if fn_name in sub_backend.__dict__: + implementations.append(sub) + return implementations diff --git a/ivy_tests/test_ivy/test_misc/test_backend_utils/__init__.py b/ivy_tests/test_ivy/test_misc/test_backend_utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/ivy_tests/test_ivy/test_misc/test_backend_handler.py b/ivy_tests/test_ivy/test_misc/test_backend_utils/test_backend_handler.py similarity index 100% rename from ivy_tests/test_ivy/test_misc/test_backend_handler.py rename to ivy_tests/test_ivy/test_misc/test_backend_utils/test_backend_handler.py diff --git a/ivy_tests/test_ivy/test_misc/test_backend_utils/test_sub_backends_available.py b/ivy_tests/test_ivy/test_misc/test_backend_utils/test_sub_backends_available.py new file mode 100644 index 0000000000000..0a2efb2259a16 --- /dev/null +++ b/ivy_tests/test_ivy/test_misc/test_backend_utils/test_sub_backends_available.py @@ -0,0 +1,45 @@ +import pytest + +from importlib.util import find_spec +import ivy + + +def test_no_warning_when_no_sub_backend_implementation_available(): + ivy.set_backend("numpy") + q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]]) + k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]]) + v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]]) + with pytest.warns(None) as record: + ivy.scaled_dot_product_attention( + q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True + ) + assert len(record) == 0 + + +@pytest.mark.skipif(find_spec("xformers") is None, reason="xformers is not installed") +def test_sub_backend_implementation_available(): + ivy.set_backend("torch") + sub_backends = ivy.available_sub_backend_implementations( + "scaled_dot_product_attention" + ) + assert "xformers" in sub_backends + + +def test_sub_backend_implementation_not_available(): + ivy.set_backend("numpy") + sub_backends = ivy.available_sub_backend_implementations( + "scaled_dot_product_attention" + ) + assert not sub_backends + + +@pytest.mark.skipif(find_spec("xformers") is None, reason="xformers is not installed") +def test_throw_warning_when_sub_backend_implementation_available_but_not_used(): + ivy.set_backend("torch") + q = ivy.array([[[0.2, 1.0], [2.2, 3.0], [4.4, 5.6]]]) + k = ivy.array([[[0.6, 1.5], [2.4, 3.3], [4.2, 5.1]]]) + v = ivy.array([[[0.4, 1.3], [2.2, 3.1], [4.3, 5.3]]]) + with pytest.warns(UserWarning): + ivy.scaled_dot_product_attention( + q, k, v, scale=1, dropout_p=0.1, is_causal=True, training=True + ) diff --git a/ivy_tests/test_ivy/test_misc/test_with_backend.py b/ivy_tests/test_ivy/test_misc/test_backend_utils/test_with_backend.py similarity index 100% rename from ivy_tests/test_ivy/test_misc/test_with_backend.py rename to ivy_tests/test_ivy/test_misc/test_backend_utils/test_with_backend.py diff --git a/ivy_tests/test_ivy/test_misc/test_func_wrapper.py b/ivy_tests/test_ivy/test_misc/test_func_wrapper.py index de3c5a73c9604..774b1a7e32cb9 100644 --- a/ivy_tests/test_ivy/test_misc/test_func_wrapper.py +++ b/ivy_tests/test_ivy/test_misc/test_func_wrapper.py @@ -209,3 +209,9 @@ def test_views(array_to_update, backend_fw): assert np.allclose(d, d_copy + 1) assert np.allclose(e[0], e_copy + 1) ivy.previous_backend() + + +def test_warn_efficient_implementations(): + to_test = ivy.func_wrapper._handle_efficient_implementation_available(ivy.array) + with pytest.warns(UserWarning): + to_test(1)