Skip to content

Commit

Permalink
Merge isintace calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sai-Suraj-27 committed Oct 15, 2023
1 parent 4154713 commit 25af5e5
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 24 deletions.
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def prod(
return x.type(dtype)
if axis is None:
return torch.prod(input=x, dtype=dtype)
if isinstance(axis, tuple) or isinstance(axis, list):
if isinstance(axis, (tuple, list)):
for i in axis:
x = torch.prod(x, i, keepdim=keepdims, dtype=dtype)
return x
Expand Down
17 changes: 13 additions & 4 deletions ivy/functional/frontends/numpy/logic/truth_value_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,17 @@ def isrealobj(x: any):

@to_ivy_arrays_and_back
def isscalar(element):
return (
isinstance(element, (int, float, complex, bool, bytes, str, memoryview))
or isinstance(element, numbers.Number)
or isinstance(element, np_frontend.generic)
return isinstance(
element,
(
int,
float,
complex,
bool,
bytes,
str,
memoryview,
numbers.Number,
np_frontend.generic,
),
)
2 changes: 1 addition & 1 deletion ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def where(draw, *, shape=None):

# noinspection PyShadowingNames
def handle_where_and_array_bools(where, input_dtype, test_flags):
if isinstance(where, list) or isinstance(where, tuple):
if isinstance(where, (list, tuple)):
where = where[0]
test_flags.as_variable += [False]
test_flags.native_arrays += [False]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_numpy_copyto(
frontend,
):
_, xs, casting, where = copyto_args
if isinstance(where, list) or isinstance(where, tuple):
if isinstance(where, (list, tuple)):
where = where[0]

with BackendHandler.update_backend(backend_fw) as ivy_backend:
Expand Down
39 changes: 24 additions & 15 deletions ivy_tests/test_ivy/test_functional/test_core/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,14 @@ def test_default_complex_dtype(
complex_dtype=complex_dtype[0],
as_native=as_native,
)
assert (
isinstance(res, ivy_backend.Dtype)
or isinstance(res, typing.get_args(ivy_backend.NativeDtype))
or isinstance(res, ivy_backend.NativeDtype)
or isinstance(res, str)
assert isinstance(
res,
(
ivy_backend.Dtype,
typing.get_args(ivy_backend.NativeDtype),
ivy_backend.NativeDtype,
str,
),
)
assert (
ivy_backend.default_complex_dtype(
Expand Down Expand Up @@ -362,11 +365,14 @@ def test_default_float_dtype(
float_dtype=float_dtype[0],
as_native=as_native,
)
assert (
isinstance(res, ivy_backend.Dtype)
or isinstance(res, typing.get_args(ivy_backend.NativeDtype))
or isinstance(res, ivy_backend.NativeDtype)
or isinstance(res, str)
assert isinstance(
res,
(
ivy_backend.Dtype,
typing.get_args(ivy_backend.NativeDtype),
ivy_backend.NativeDtype,
str,
),
)
assert (
ivy_backend.default_float_dtype(
Expand Down Expand Up @@ -401,11 +407,14 @@ def test_default_int_dtype(
int_dtype=int_dtype[0],
as_native=as_native,
)
assert (
isinstance(res, ivy_backend.Dtype)
or isinstance(res, typing.get_args(ivy_backend.NativeDtype))
or isinstance(res, ivy_backend.NativeDtype)
or isinstance(res, str)
assert isinstance(
res,
(
ivy_backend.Dtype,
typing.get_args(ivy_backend.NativeDtype),
ivy_backend.NativeDtype,
str,
),
)
assert (
ivy_backend.default_int_dtype(input=None, int_dtype=None, as_native=False)
Expand Down
4 changes: 2 additions & 2 deletions ivy_tests/test_ivy/test_stateful/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,15 @@ def model_assert(mod, on_device):
for key, obj in mod.v.items():
if isinstance(obj, ivy.Module):
return model_assert(obj, on_device)
if isinstance(obj, ivy.Container) or isinstance(obj, dict):
if isinstance(obj, (ivy.Container, dict)):
for item1, item2 in obj.items():
assertion(item2.device, on_device)

else:
assertion(obj.device, on_device)
if getattr(mod, "buffers", None):
for key, obj in mod.buffers.items():
if isinstance(obj, ivy.Container) or isinstance(obj, dict):
if isinstance(obj, (ivy.Container, dict)):
ivy.nested_map(lambda x: assertion(x.device, on_device), obj)
else:
assertion(obj.device, on_device)
Expand Down

0 comments on commit 25af5e5

Please sign in to comment.