Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: nested maps on torch.compile for torch tensors #23108

Merged
merged 2 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ivy/data_classes/array/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _to_ivy(x: Any) -> Any:
def to_ivy(
x: Union[ivy.Array, ivy.NativeArray, Iterable],
nested: bool = False,
include_derived: Optional[Dict[type, bool]] = None,
include_derived: Optional[Dict[str, bool]] = None,
) -> Union[ivy.Array, ivy.NativeArray, Iterable]:
"""
Return the input array converted to an ivy.Array instance if it is a native array
Expand Down Expand Up @@ -84,7 +84,7 @@ def to_ivy(

def args_to_ivy(
*args: Iterable[Any],
include_derived: Optional[Dict[type, bool]] = None,
include_derived: Optional[Dict[str, bool]] = None,
**kwargs: Dict[str, Any],
) -> Tuple[Iterable[Any], Dict[str, Any]]:
"""
Expand Down Expand Up @@ -115,7 +115,7 @@ def args_to_ivy(
def to_native(
x: Union[ivy.Array, ivy.NativeArray, Iterable],
nested: bool = False,
include_derived: Optional[Dict[type, bool]] = None,
include_derived: Optional[Dict[str, bool]] = None,
cont_inplace: bool = False,
to_ignore: Optional[Union[type, Tuple[type]]] = None,
) -> Union[ivy.Array, ivy.NativeArray, Iterable]:
Expand Down Expand Up @@ -157,7 +157,7 @@ def to_native(

def args_to_native(
*args: Iterable[Any],
include_derived: Dict[type, bool] = None,
include_derived: Dict[str, bool] = None,
cont_inplace: bool = False,
to_ignore: Optional[Union[type, Tuple[type]]] = None,
**kwargs: Dict[str, Any],
Expand Down
8 changes: 4 additions & 4 deletions ivy/data_classes/container/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class _ContainerWithConversions(ContainerBase):
def _static_to_native(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
nested: Union[bool, ivy.Container] = False,
include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None,
include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
Expand Down Expand Up @@ -78,7 +78,7 @@ def _static_to_native(
def to_native(
self: ivy.Container,
nested: Union[bool, ivy.Container] = False,
include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None,
include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
Expand Down Expand Up @@ -138,7 +138,7 @@ def to_native(
def _static_to_ivy(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
nested: Union[bool, ivy.Container] = False,
include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None,
include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
Expand Down Expand Up @@ -199,7 +199,7 @@ def _static_to_ivy(
def to_ivy(
self: ivy.Container,
nested: Union[bool, ivy.Container] = False,
include_derived: Optional[Union[Dict[type, bool], ivy.Container]] = None,
include_derived: Optional[Union[Dict[str, bool], ivy.Container]] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
Expand Down
6 changes: 3 additions & 3 deletions ivy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def _inputs_to_ivy_arrays(*args, **kwargs):
has_out = True
# convert all arrays in the inputs to ivy.Array instances
ivy_args, ivy_kwargs = ivy.args_to_ivy(
*args, **kwargs, include_derived={tuple: True}
*args, **kwargs, include_derived={"tuple": True}
)
if has_out:
ivy_kwargs["out"] = out
Expand Down Expand Up @@ -564,7 +564,7 @@ def _outputs_to_ivy_arrays(*args, **kwargs):
ret = fn(*args, **kwargs)
# convert all arrays in the return to `ivy.Array` instances
return (
ivy.to_ivy(ret, nested=True, include_derived={tuple: True})
ivy.to_ivy(ret, nested=True, include_derived={"tuple": True})
if ivy.array_mode
else ret
)
Expand Down Expand Up @@ -594,7 +594,7 @@ def output_to_native_arrays(fn: Callable) -> Callable:
@functools.wraps(fn)
def _output_to_native_arrays(*args, **kwargs):
ret = fn(*args, **kwargs)
return ivy.to_native(ret, nested=True, include_derived={tuple: True})
return ivy.to_native(ret, nested=True, include_derived={"tuple": True})

_output_to_native_arrays.outputs_to_native_arrays = True
return _output_to_native_arrays
Expand Down
8 changes: 4 additions & 4 deletions ivy/functional/frontends/jax/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def _inputs_to_ivy_arrays_jax(*args, **kwargs):
has_out = True
# convert all arrays in the inputs to ivy.Array instances
new_args = ivy.nested_map(
args, _to_ivy_array, include_derived={tuple: True}, shallow=False
args, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
new_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False
kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
# add the original out argument back to the keyword arguments
if has_out:
Expand Down Expand Up @@ -153,10 +153,10 @@ def _outputs_to_frontend_arrays_jax(*args, **kwargs):
return _from_ivy_array_to_jax_frontend_array_weak_type(
ret,
nested=True,
include_derived={tuple: True},
include_derived={"tuple": True},
)
return _from_ivy_array_to_jax_frontend_array(
ret, nested=True, include_derived={tuple: True}
ret, nested=True, include_derived={"tuple": True}
)

return _outputs_to_frontend_arrays_jax
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/frontends/mxnet/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def _inputs_to_ivy_arrays_mxnet(*args, **kwargs):
"""
# convert all arrays in the inputs to ivy.Array instances
new_args = ivy.nested_map(
args, _to_ivy_array, include_derived={tuple: True}, shallow=False
args, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
new_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False
kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
return fn(*new_args, **new_kwargs)

Expand All @@ -105,7 +105,7 @@ def _outputs_to_frontend_arrays_mxnet(*args, **kwargs):
ret = fn(*args, **kwargs)

# convert all arrays in the return to `frontend.Tensorflow.tensor` instances
return ivy.nested_map(ret, _ivy_array_to_mxnet, include_derived={tuple: True})
return ivy.nested_map(ret, _ivy_array_to_mxnet, include_derived={"tuple": True})

_outputs_to_frontend_arrays_mxnet.outputs_to_frontend_arrays = True
return _outputs_to_frontend_arrays_mxnet
Expand Down
10 changes: 5 additions & 5 deletions ivy/functional/frontends/numpy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _set_order(args, order):
)
if order in ["K", "A", None]:
check_order = ivy.nested_map(
args, _check_C_order, include_derived={tuple: True}, shallow=False
args, _check_C_order, include_derived={"tuple": True}, shallow=False
)
if all(v is None for v in check_order) or any(
ivy.multi_index_nest(check_order, ivy.all_nested_indices(check_order))
Expand Down Expand Up @@ -447,9 +447,9 @@ def _inputs_to_ivy_arrays_np(*args, **kwargs):
The return of the function, with ivy arrays passed in the arguments.
"""
# convert all arrays in the inputs to ivy.Array instances
ivy_args = ivy.nested_map(args, _to_ivy_array, include_derived={tuple: True})
ivy_args = ivy.nested_map(args, _to_ivy_array, include_derived={"tuple": True})
ivy_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={tuple: True}
kwargs, _to_ivy_array, include_derived={"tuple": True}
)
return fn(*ivy_args, **ivy_kwargs)

Expand Down Expand Up @@ -509,10 +509,10 @@ def _outputs_to_frontend_arrays(*args, order="K", **kwargs):
# convert all returned arrays to `ndarray` instances
if order == "F":
return ivy.nested_map(
ret, _ivy_to_numpy_order_F, include_derived={tuple: True}
ret, _ivy_to_numpy_order_F, include_derived={"tuple": True}
)
else:
return ivy.nested_map(ret, _ivy_to_numpy, include_derived={tuple: True})
return ivy.nested_map(ret, _ivy_to_numpy, include_derived={"tuple": True})

if "order" in list(inspect.signature(fn).parameters.keys()):
contains_order = True
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/frontends/onnx/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def _inputs_to_ivy_arrays_onnx(*args, **kwargs):
"""
# convert all arrays in the inputs to ivy.Array instances
new_args = ivy.nested_map(
args, _to_ivy_array, include_derived={tuple: True}, shallow=False
args, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
new_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False
kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
return fn(*new_args, **new_kwargs)

Expand All @@ -82,7 +82,7 @@ def _outputs_to_frontend_arrays_onnx(*args, **kwargs):

# convert all arrays in the return to `frontend.onnx.Tensor` instances
return _from_ivy_array_to_onnx_frontend_tensor(
ret, nested=True, include_derived={tuple: True}
ret, nested=True, include_derived={"tuple": True}
)

return _outputs_to_frontend_arrays_onnx
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/frontends/paddle/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def new_fn(*args, **kwargs):
"""
# convert all input arrays to ivy.Array instances
new_args = ivy.nested_map(
args, _to_ivy_array, include_derived={tuple: True}, shallow=False
args, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
new_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False
kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)

return fn(*new_args, **new_kwargs)
Expand Down Expand Up @@ -82,7 +82,7 @@ def new_fn(*args, **kwargs):
ivy.unset_default_float_dtype()
# convert all arrays in the return to `paddle_frontend.Tensor` instances
return _from_ivy_array_to_paddle_frontend_tensor(
ret, nested=True, include_derived={tuple: True}
ret, nested=True, include_derived={"tuple": True}
)

return new_fn
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/tensorflow/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _outputs_to_frontend_arrays_tf(*args, **kwargs):

# convert all arrays in the return to `frontend.Tensorflow.tensor` instances
return ivy.nested_map(
ret, _ivy_array_to_tensorflow, include_derived={tuple: True}
ret, _ivy_array_to_tensorflow, include_derived={"tuple": True}
)

_outputs_to_frontend_arrays_tf.outputs_to_frontend_arrays = True
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/frontends/torch/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def _inputs_to_ivy_arrays_torch(*args, **kwargs):
)
# convert all input arrays to ivy.Array instances
new_args = ivy.nested_map(
args, _to_ivy_array, include_derived={tuple: True}, shallow=False
args, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
new_kwargs = ivy.nested_map(
kwargs, _to_ivy_array, include_derived={tuple: True}, shallow=False
kwargs, _to_ivy_array, include_derived={"tuple": True}, shallow=False
)
return fn(*new_args, **new_kwargs)

Expand Down Expand Up @@ -202,7 +202,7 @@ def outputs_to_frontend_arrays_torch(*args, **kwargs):
ret = _from_ivy_array_to_torch_frontend_tensor(
ret,
nested=True,
include_derived={tuple: True},
include_derived={"tuple": True},
requires_grad=kwargs.get(
"requires_grad",
any(
Expand Down
14 changes: 7 additions & 7 deletions ivy/functional/ivy/nest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,7 @@ def nested_map(
x: Union[ivy.Array, ivy.NativeArray, Iterable],
/,
fn: Callable,
include_derived: Optional[Union[Dict[type, bool], bool]] = None,
include_derived: Optional[Union[Dict[str, bool], bool]] = None,
to_ignore: Optional[Union[type, Tuple[type]]] = None,
to_mutable: bool = False,
max_depth: Optional[int] = None,
Expand Down Expand Up @@ -1160,12 +1160,13 @@ def nested_map(
to_ignore = ivy.default(to_ignore, ())
extra_nest_types = ivy.default(extra_nest_types, ())
if include_derived is True:
include_derived = {tuple: True, list: True, dict: True}
include_derived = {"tuple": True, "list": True, "dict": True}
elif not include_derived:
include_derived = {}
for t in (tuple, list, dict):
for t in ("tuple", "list", "dict"):
if t not in include_derived:
include_derived[t] = False
# to ensure all keys are strings
if ivy.exists(max_depth) and _depth > max_depth:
return x
class_instance = type(x)
Expand All @@ -1182,27 +1183,26 @@ def nested_map(
_tuple_check_fn,
(
(lambda x_, t_: isinstance(x_, t_))
if include_derived[tuple]
if include_derived["tuple"]
else (lambda x_, t_: type(x_) is t_)
),
)
list_check_fn = ivy.default(
_list_check_fn,
(
(lambda x_, t_: isinstance(x_, t_))
if include_derived[list]
if include_derived["list"]
else (lambda x_, t_: type(x_) is t_)
),
)
dict_check_fn = ivy.default(
_dict_check_fn,
(
(lambda x_, t_: isinstance(x_, t_))
if include_derived[dict]
if include_derived["dict"]
else (lambda x_, t_: type(x_) is t_)
),
)

if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore):
ret_list = [
nested_map(
Expand Down
Loading