diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py index bd1c762e1168a..33ef586c05b28 100644 --- a/ivy/functional/backends/torch/statistical.py +++ b/ivy/functional/backends/torch/statistical.py @@ -121,7 +121,7 @@ def prod( return x.type(dtype) if axis is None: return torch.prod(input=x, dtype=dtype) - if isinstance(axis, tuple) or isinstance(axis, list): + if isinstance(axis, (tuple, list)): for i in axis: x = torch.prod(x, i, keepdim=keepdims, dtype=dtype) return x diff --git a/ivy/functional/frontends/numpy/logic/truth_value_testing.py b/ivy/functional/frontends/numpy/logic/truth_value_testing.py index 6af6804ff72fb..0421def6ee57e 100644 --- a/ivy/functional/frontends/numpy/logic/truth_value_testing.py +++ b/ivy/functional/frontends/numpy/logic/truth_value_testing.py @@ -78,8 +78,17 @@ def isrealobj(x: any): @to_ivy_arrays_and_back def isscalar(element): - return ( - isinstance(element, (int, float, complex, bool, bytes, str, memoryview)) - or isinstance(element, numbers.Number) - or isinstance(element, np_frontend.generic) + return isinstance( + element, + ( + int, + float, + complex, + bool, + bytes, + str, + memoryview, + numbers.Number, + np_frontend.generic, + ), ) diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py b/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py index 27d42cbffa2ae..121c60e3a67fb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py @@ -311,7 +311,7 @@ def where(draw, *, shape=None): # noinspection PyShadowingNames def handle_where_and_array_bools(where, input_dtype, test_flags): - if isinstance(where, list) or isinstance(where, tuple): + if isinstance(where, (list, tuple)): where = where[0] test_flags.as_variable += [False] test_flags.native_arrays += [False] diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py index 822ccfe4b46ab..25df72b029fd5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_basic_operations.py @@ -42,7 +42,7 @@ def test_numpy_copyto( frontend, ): _, xs, casting, where = copyto_args - if isinstance(where, list) or isinstance(where, tuple): + if isinstance(where, (list, tuple)): where = where[0] with BackendHandler.update_backend(backend_fw) as ivy_backend: diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py b/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py index 23ada8a7c96fe..9f46084fede6d 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_dtype.py @@ -300,11 +300,14 @@ def test_default_complex_dtype( complex_dtype=complex_dtype[0], as_native=as_native, ) - assert ( - isinstance(res, ivy_backend.Dtype) - or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) - or isinstance(res, ivy_backend.NativeDtype) - or isinstance(res, str) + assert isinstance( + res, + ( + ivy_backend.Dtype, + typing.get_args(ivy_backend.NativeDtype), + ivy_backend.NativeDtype, + str, + ), ) assert ( ivy_backend.default_complex_dtype( @@ -362,11 +365,14 @@ def test_default_float_dtype( float_dtype=float_dtype[0], as_native=as_native, ) - assert ( - isinstance(res, ivy_backend.Dtype) - or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) - or isinstance(res, ivy_backend.NativeDtype) - or isinstance(res, str) + assert isinstance( + res, + ( + ivy_backend.Dtype, + typing.get_args(ivy_backend.NativeDtype), + ivy_backend.NativeDtype, + str, + ), ) assert ( ivy_backend.default_float_dtype( @@ -401,11 +407,14 @@ def test_default_int_dtype( int_dtype=int_dtype[0], as_native=as_native, ) - assert ( - isinstance(res, ivy_backend.Dtype) - or isinstance(res, typing.get_args(ivy_backend.NativeDtype)) - or isinstance(res, ivy_backend.NativeDtype) - or isinstance(res, str) + assert isinstance( + res, + ( + ivy_backend.Dtype, + typing.get_args(ivy_backend.NativeDtype), + ivy_backend.NativeDtype, + str, + ), ) assert ( ivy_backend.default_int_dtype(input=None, int_dtype=None, as_native=False) diff --git a/ivy_tests/test_ivy/test_stateful/test_modules.py b/ivy_tests/test_ivy/test_stateful/test_modules.py index 020856341838d..1a361cba5fa42 100644 --- a/ivy_tests/test_ivy/test_stateful/test_modules.py +++ b/ivy_tests/test_ivy/test_stateful/test_modules.py @@ -403,7 +403,7 @@ def model_assert(mod, on_device): for key, obj in mod.v.items(): if isinstance(obj, ivy.Module): return model_assert(obj, on_device) - if isinstance(obj, ivy.Container) or isinstance(obj, dict): + if isinstance(obj, (ivy.Container, dict)): for item1, item2 in obj.items(): assertion(item2.device, on_device) @@ -411,7 +411,7 @@ def model_assert(mod, on_device): assertion(obj.device, on_device) if getattr(mod, "buffers", None): for key, obj in mod.buffers.items(): - if isinstance(obj, ivy.Container) or isinstance(obj, dict): + if isinstance(obj, (ivy.Container, dict)): ivy.nested_map(lambda x: assertion(x.device, on_device), obj) else: assertion(obj.device, on_device)