(refactor): pass fn as first argument to nested_map to follow convention and remove redundant logic from nested functions (ivy-llc#23538)
mattbarrett98 authored and druvdub committed Oct 14, 2023
1 parent 469125d commit 50a9a7e
Showing 32 changed files with 156 additions and 329 deletions.
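For context, the refactor moves the mapping function to the first positional argument of ivy.nested_map, with the nested structure second, mirroring Python's built-in map(fn, iterable). A minimal sketch of the new calling convention follows; the sample nest and values are illustrative rather than taken from this commit, and the printed result assumes nested_map's default handling of plain dicts, lists, and tuples.

import ivy

# Illustrative nested structure of plain Python containers (hypothetical example).
nest = {"a": [1, 2], "b": (3, 4)}

# Old argument order (before this commit):
#     ivy.nested_map(nest, lambda x: x * 2)
# New argument order (after this commit): the function comes first.
doubled = ivy.nested_map(lambda x: x * 2, nest)
print(doubled)  # expected: {'a': [2, 4], 'b': (6, 8)}
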
12 changes: 6 additions & 6 deletions ivy/data_classes/array/conversions.py
@@ -78,7 +78,7 @@ def to_ivy(
the input in its native framework form in the case of ivy.Array or instances.
"""
if nested:
return ivy.nested_map(x, _to_ivy, include_derived, shallow=False)
return ivy.nested_map(_to_ivy, x, include_derived, shallow=False)
return _to_ivy(x)


@@ -107,8 +107,8 @@ def args_to_ivy(
the same arguments, with any nested arrays converted to ivy.Array or
instances.
"""
native_args = ivy.nested_map(args, _to_ivy, include_derived, shallow=False)
native_kwargs = ivy.nested_map(kwargs, _to_ivy, include_derived, shallow=False)
native_args = ivy.nested_map(_to_ivy, args, include_derived, shallow=False)
native_kwargs = ivy.nested_map(_to_ivy, kwargs, include_derived, shallow=False)
return native_args, native_kwargs


@@ -147,8 +147,8 @@ def to_native(
"""
if nested:
return ivy.nested_map(
x,
lambda x: _to_native(x, inplace=cont_inplace, to_ignore=to_ignore),
x,
include_derived,
shallow=False,
)
@@ -188,14 +188,14 @@ def args_to_native(
native form.
"""
native_args = ivy.nested_map(
args,
lambda x: _to_native(x, inplace=cont_inplace, to_ignore=to_ignore),
args,
include_derived,
shallow=False,
)
native_kwargs = ivy.nested_map(
kwargs,
lambda x: _to_native(x, inplace=cont_inplace, to_ignore=to_ignore),
kwargs,
include_derived,
shallow=False,
)
2 changes: 1 addition & 1 deletion ivy/data_classes/container/base.py
@@ -3265,7 +3265,7 @@ def cont_map(
return_dict[key] = ret
elif isinstance(value, (list, tuple)) and map_sequences:
ret = ivy.nested_map(
value, lambda x: func(x, None), True, shallow=False
lambda x: func(x, None), value, True, shallow=False
)
if prune_unapplied and not ret:
continue
2 changes: 1 addition & 1 deletion ivy/data_classes/nested_array/base.py
@@ -218,7 +218,7 @@ def broadcast_shapes(shapes):

def ragged_map(self, fn):
arg = ivy.copy_nest(self._data)
ivy.nested_map(arg, lambda x: fn(x), shallow=True)
ivy.nested_map(lambda x: fn(x), arg, shallow=True)
# infer dtype, shape, and device from the first array in the ret data
arr0_id = ivy.nested_argwhere(arg, ivy.is_ivy_array, stop_after_n_found=1)[0]
arr0 = ivy.index_nest(arg, arr0_id)
12 changes: 6 additions & 6 deletions ivy/func_wrapper.py
@@ -508,8 +508,8 @@ def inputs_to_native_shapes(fn: Callable) -> Callable:
@functools.wraps(fn)
def _inputs_to_native_shapes(*args, **kwargs):
args, kwargs = ivy.nested_map(
[args, kwargs],
lambda x: (x.shape if isinstance(x, ivy.Shape) and ivy.array_mode else x),
[args, kwargs],
)
return fn(*args, **kwargs)

@@ -521,8 +521,8 @@ def outputs_to_ivy_shapes(fn: Callable) -> Callable:
@functools.wraps(fn)
def _outputs_to_ivy_shapes(*args, **kwargs):
args, kwargs = ivy.nested_map(
[args, kwargs],
lambda x: (x.shape if isinstance(x, ivy.Shape) and ivy.array_mode else x),
[args, kwargs],
)
return fn(*args, **kwargs)

@@ -635,8 +635,8 @@ def frontend_outputs_to_ivy_arrays(fn: Callable) -> Callable:
def _outputs_to_ivy_arrays(*args, **kwargs):
ret = fn(*args, **kwargs)
return ivy.nested_map(
ret,
lambda x: x.ivy_array if hasattr(x, "ivy_array") else x,
ret,
shallow=False,
)

@@ -1194,8 +1194,8 @@ def mini_helper(x):
x = ivy.to_native(ivy.astype(x, ivy.as_native_dtype(dtype)))
return x

args = ivy.nested_map(args, mini_helper, include_derived=True)
kwargs = ivy.nested_map(kwargs, mini_helper)
args = ivy.nested_map(mini_helper, args, include_derived=True)
kwargs = ivy.nested_map(mini_helper, kwargs)
return fn(*args, **kwargs)

return method
@@ -1589,7 +1589,7 @@ def func(x):
)
return x

ivy.nested_map(array_vals, func, include_derived=True)
ivy.nested_map(func, array_vals, include_derived=True)

return fn(*args, **kwargs)

4 changes: 2 additions & 2 deletions ivy/functional/backends/jax/gradients.py
@@ -38,7 +38,7 @@ def _forward_fn(
):
"""Forward function for gradient calculation."""
# Setting x(relevant variables) into xs(all variables)
x = ivy.nested_map(x, ivy.to_ivy, include_derived=True)
x = ivy.nested_map(ivy.to_ivy, x, include_derived=True)
x_arr_idxs = ivy.nested_argwhere(x, ivy.is_array)
x_arr_values = ivy.multi_index_nest(x, x_arr_idxs)
if xs_grad_idxs is not None:
@@ -131,7 +131,7 @@ def value_and_grad(func):
grad_fn = lambda xs: ivy.to_native(func(xs))

def callback_fn(xs):
xs = ivy.nested_map(xs, lambda x: ivy.to_native(x), include_derived=True)
xs = ivy.nested_map(lambda x: ivy.to_native(x), xs, include_derived=True)
value, grad = jax.value_and_grad(grad_fn)(xs)
return ivy.to_ivy(value), ivy.to_ivy(grad)

6 changes: 3 additions & 3 deletions ivy/functional/backends/numpy/gradients.py
@@ -50,7 +50,7 @@ def value_and_grad(func):

def grad_fn(xs):
grads = ivy.nested_map(
xs, lambda x: ivy.zeros_like(x), include_derived=True, shallow=False
lambda x: ivy.zeros_like(x), xs, include_derived=True, shallow=False
)
y = func(xs)
y = ivy.to_ivy(y)
@@ -67,7 +67,7 @@ def jac(func):

def grad_fn(xs):
jacobian = ivy.nested_map(
xs, lambda x: ivy.zeros_like(x), include_derived=True, shallow=False
lambda x: ivy.zeros_like(x), xs, include_derived=True, shallow=False
)
return jacobian

@@ -82,7 +82,7 @@ def grad(func, argnums=0):

def grad_fn(xs):
grad = ivy.nested_map(
xs, lambda x: ivy.zeros_like(x), include_derived=True, shallow=False
lambda x: ivy.zeros_like(x), xs, include_derived=True, shallow=False
)
y = func(xs)
y = ivy.to_ivy(y)
2 changes: 1 addition & 1 deletion ivy/functional/backends/paddle/creation.py
@@ -104,7 +104,7 @@ def asarray(
return paddle_backend.squeeze(
paddle.to_tensor(obj, dtype=dtype, place=device), axis=0
)
obj = ivy.nested_map(obj, _remove_np_bfloat16, shallow=False)
obj = ivy.nested_map(_remove_np_bfloat16, obj, shallow=False)
return paddle.to_tensor(obj, dtype=dtype, place=device)


12 changes: 6 additions & 6 deletions ivy/functional/backends/paddle/gradients.py
@@ -44,8 +44,8 @@ def _grad_func(y, xs, retain_grads):
"""Gradient calculation function."""
# Creating a zero gradient nest for the case where no gradients are computed
grads_ = ivy.nested_map(
xs,
lambda x: (paddle.to_tensor([0.0]) if x is None else paddle.zeros_like(x)),
xs,
include_derived=True,
shallow=False,
)
@@ -78,7 +78,7 @@ def _grad_func(y, xs, retain_grads):
# Returning zeros if no gradients are computed for consistent results
if isinstance(grads, ivy.Container):
grads = ivy.nested_map(
grads, lambda x: 0 if x is None else x, include_derived=True
lambda x: 0 if x is None else x, grads, include_derived=True
)
grads = ivy.add(grads, grads_)
else:
@@ -96,7 +96,7 @@ def grad_(x):
)[0]
return grad if grad is not None else paddle.zeros_like(x)

grads = ivy.nested_map(xs, grad_, include_derived=True, shallow=False)
grads = ivy.nested_map(grad_, xs, include_derived=True, shallow=False)
grads = ivy.nested_multi_map(
lambda x, _: (paddle_backend.add(x[0], x[1])), [grads, grads_]
)
@@ -172,7 +172,7 @@ def autograd_fn(x):
grad = ivy.to_ivy(grad)
return grad

grads = ivy.nested_map(xs, autograd_fn, include_derived=True, shallow=False)
grads = ivy.nested_map(autograd_fn, xs, include_derived=True, shallow=False)
y = ivy.to_ivy(y)
return y, grads

@@ -215,14 +215,14 @@ def one_out_fn(o):
out_shape = ivy.index_nest(grad_fn(xs), out_idx).shape
one_arg_fn = _get_jac_one_arg_fn(grad_fn, xs, out_idx)
jacobian = ivy.nested_map(
xs,
lambda x: jacobian_to_ivy(
paddle.incubate.autograd.Jacobian(
one_arg_fn, ivy.to_native(x.expand_dims())
),
x.shape,
out_shape,
),
xs,
shallow=False,
)
return jacobian
@@ -247,7 +247,7 @@ def jac(func: Callable):
def callback_fn(xs):
fn_ret = grad_fn(xs)
one_out_fn = _get_one_out_fn(grad_fn, xs, fn_ret)
jacobian = ivy.nested_map(fn_ret, one_out_fn)
jacobian = ivy.nested_map(one_out_fn, fn_ret)
return jacobian

return callback_fn
16 changes: 8 additions & 8 deletions ivy/functional/backends/tensorflow/gradients.py
@@ -38,8 +38,8 @@ def _grad_func(y, xs, xs_required, tape):
"""Gradient calculation function."""
# Creating a zero gradient nest for the case where no gradients are computed
grads_ = ivy.nested_map(
xs_required,
lambda x: ivy.to_native(ivy.zeros_like(x)),
xs_required,
include_derived=True,
shallow=False,
)
@@ -52,8 +52,8 @@ def _grad_func(y, xs, xs_required, tape):
grads = grads_ if grads is None else grads
else:
grads = ivy.nested_map(
grads,
lambda x: 0 if x is None else x,
grads,
include_derived=True,
)
if isinstance(grads, ivy.Container):
@@ -97,8 +97,8 @@ def execute_with_gradients(
# Gradient calculation for multiple outputs
y = _get_native_y(y)
grads_ = ivy.nested_map(
y,
lambda x: _grad_func(x, xs, xs_required, tape),
y,
include_derived=True,
shallow=False,
)
@@ -120,17 +120,17 @@ def execute_with_gradients(
def value_and_grad(func):
def grad_fn(xs):
grads = ivy.nested_map(
xs, lambda x: ivy.zeros_like(x), include_derived=True, shallow=False
lambda x: ivy.zeros_like(x), xs, include_derived=True, shallow=False
)
with tf.GradientTape(watch_accessed_variables=False) as tape:
xs = ivy.nested_map(xs, lambda x: ivy.to_native(x), include_derived=True)
xs = ivy.nested_map(lambda x: ivy.to_native(x), xs, include_derived=True)
tape.watch(xs)
y = func(xs)
y = y.to_native(y)
grads_ = tape.gradient(y, xs)
grads_ = ivy.nested_map(
grads_,
lambda x: ivy.to_ivy(x),
grads_,
include_derived=True,
)
grads_ = ivy.to_ivy(grads_)
@@ -170,19 +170,19 @@ def jac(func: Callable):

def callback_fn(x_in):
with tf.GradientTape(persistent=True) as tape:
ivy.nested_map(x_in, ivy.copy_array)
ivy.nested_map(ivy.copy_array, x_in)
x_in = ivy.to_native(x_in, nested=True)
tape.watch(x_in)
y = grad_fn(x_in)

# Deal with multiple outputs
if not isinstance(y, ivy.NativeArray):
jacobian = ivy.nested_map(
y,
lambda yi: ivy.to_ivy(
tape.jacobian(yi, x_in, unconnected_gradients="zero"),
nested=True,
),
y,
include_derived=True,
)
else:
2 changes: 1 addition & 1 deletion ivy/functional/backends/torch/creation.py
@@ -119,7 +119,7 @@ def asarray(
device: torch.device,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
obj = ivy.nested_map(obj, _remove_np_bfloat16, shallow=False)
obj = ivy.nested_map(_remove_np_bfloat16, obj, shallow=False)
if isinstance(obj, Sequence) and len(obj) != 0:
contain_tensor = ivy.nested_any(obj, lambda x: isinstance(x, torch.Tensor))
# if `obj` is a list of specifically tensors or
8 changes: 4 additions & 4 deletions ivy/functional/backends/torch/gradients.py
@@ -39,8 +39,8 @@ def _grad_func(y, xs, retain_grads):
"""Gradient calculation function."""
# Creating a zero gradient nest for the case where no gradients are computed
grads_ = ivy.nested_map(
xs,
lambda x: ivy.to_native(ivy.zeros_like(x)),
xs,
include_derived=True,
shallow=False,
)
@@ -70,7 +70,7 @@ def _grad_func(y, xs, retain_grads):
# Returning zeros if no gradients are computed for consistent results
if isinstance(grads, ivy.Container):
grads = ivy.nested_map(
grads, lambda x: 0 if x is None else x, include_derived=True
lambda x: 0 if x is None else x, grads, include_derived=True
)
grads += grads_
else:
@@ -87,7 +87,7 @@ def grad_(x):
)[0]
return grad if grad is not None else 0

grads = ivy.nested_map(xs, grad_, include_derived=True, shallow=False)
grads = ivy.nested_map(grad_, xs, include_derived=True, shallow=False)
grads = ivy.nested_multi_map(lambda x, _: (x[0] + x[1]), [grads, grads_])
return grads

@@ -157,7 +157,7 @@ def autograd_fn(x):
grad = ivy.to_ivy(grad)
return grad

grads = ivy.nested_map(xs, autograd_fn, include_derived=True, shallow=False)
grads = ivy.nested_map(autograd_fn, xs, include_derived=True, shallow=False)
y = ivy.to_ivy(y)
return y, grads

8 changes: 4 additions & 4 deletions ivy/functional/frontends/jax/func_wrapper.py
@@ -16,7 +16,7 @@
def _from_ivy_array_to_jax_frontend_array(x, nested=False, include_derived=None):
if nested:
return ivy.nested_map(
x, _from_ivy_array_to_jax_frontend_array, include_derived, shallow=False
_from_ivy_array_to_jax_frontend_array, x, include_derived, shallow=False
)
elif isinstance(x, ivy.Array):
return jax_frontend.Array(x)
Expand All @@ -28,8 +28,8 @@ def _from_ivy_array_to_jax_frontend_array_weak_type(
):
if nested:
return ivy.nested_map(
x,
_from_ivy_array_to_jax_frontend_array_weak_type,
x,
include_derived,
shallow=False,
)
@@ -111,10 +111,10 @@ def _inputs_to_ivy_arrays_jax(*args, **kwargs):
has_out = True
# convert all arrays in the inputs to ivy.Array instances
new_args = ivy.nested_map(
args, _to_ivy_array, include_derived={"tuple": True}, shallow=False
_to_ivy_array, args, include_derived={"tuple": True}, shallow=False
)
new_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False
_to_ivy_array, kwargs, include_derived={"tuple": True}, shallow=False
)
# add the original out argument back to the keyword arguments
if has_out:
(Diffs for the remaining changed files are not shown here.)

