Skip to content

Commit

Permalink
Add torch.nn.Parameter, torch.nn.Module and ivy.Module.to_keras_modul…
Browse files Browse the repository at this point in the history
…e for KLA
  • Loading branch information
hmahmood24 committed Nov 4, 2023
1 parent 9d93746 commit 6daeb0e
Show file tree
Hide file tree
Showing 9 changed files with 1,459 additions and 0 deletions.
121 changes: 121 additions & 0 deletions ivy/data_classes/array/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@
# --------#


# TODO: Need to find a better way to do this.
# Temporarily adding as is for the
# `ivy.Module.to_keras_module`method
# for the KLA demo. Do not move/remove.
ARRAY_TO_BACKEND = {
"ndarray": "numpy",
"Tensor": ["torch", "paddle"],
"Parameter": "torch",
"EagerTensor": "tensorflow",
"ResourceVariable": "tensorflow",
"DeviceArray": "jax",
"Array": "jax",
"ArrayImpl": "jax",
"EagerParamBase": "paddle",
}


def _to_native(x: Any, inplace: bool = False, to_ignore: tuple = ()) -> Any:
to_ignore = ivy.default(to_ignore, ())
if isinstance(x, to_ignore):
Expand Down Expand Up @@ -200,3 +217,107 @@ def args_to_native(
shallow=False,
)
return native_args, native_kwargs


# TODO: Need to find a better way to do this.
# Temporarily adding as is for the
# `ivy.Module.to_keras_module`method
# for the . Do not move/remove.
def array_to_new_backend(
x: Union[ivy.Array, ivy.NativeArray],
native: bool = False,
) -> Union[ivy.Array, ivy.NativeArray]:
native_x = x.data if isinstance(x, ivy.Array) else x
native_x_type = type(native_x).__name__

# Modify native_type here since @tf.function converts tf.EagerTensor into
# tf.Tensor when running @tf.function on a transpiled graph
if ivy.current_backend_str() == "tensorflow":
import importlib

native_x_type = (
"EagerTensor"
if not importlib.import_module("tensorflow").executing_eagerly()
and isinstance(native_x, importlib.import_module("tensorflow").Tensor)
else native_x_type
)

# Check for paddle first, as it shares the 'Tensor' native_x_type with torch
if "paddle" in str(native_x.__class__) and ivy.current_backend_str() == "paddle":
if native:
return native_x
else:
return x

if hasattr(x, "_ivy_array"):
return x

# Check if the other possible backends match with the native data type
if (
native_x_type in ARRAY_TO_BACKEND
and ivy.current_backend_str() in ARRAY_TO_BACKEND[native_x_type]
):
if ivy.current_backend_str() == "torch":
if "torch" in str(native_x.__class__):
# torch and paddle both use 'Tensor', return if this is torch
return x
else:
# if it's actually a paddle tensor, convert to an ivy array
ret = ivy.array(native_x.numpy())
return ret.data if native else ret
if ivy.current_backend_str() == "paddle":
if "paddle" in str(native_x.__class__):
# torch and paddle both use 'Tensor', return if this is paddle
return x
else:
# if it's actually a torch tensor, convert to an ivy array
ret = ivy.array(native_x.numpy())
return ret.data if native else ret
return x

if native_x_type not in ARRAY_TO_BACKEND:
return x
native_x = (
native_x.detach().cpu()
if native_x_type in ["Parameter", "Tensor"]
else native_x
)
np_intermediary = np.array(native_x)
ret = ivy.array(np_intermediary)
return ret.data if native else ret


# TODO: Need to find a better way to do this.
# Temporarily adding as is for the
# `ivy.Module.to_keras_module()`method
# for the KLA demo. Do not move/remove.
def nest_array_to_new_backend(
nest: Iterable[Union[ivy.Array, ivy.NativeArray]],
native: bool = True,
shallow: bool = True,
) -> Iterable[Union[ivy.Array, ivy.NativeArray]]:
"""
Convert a given ivy.Array or ivy.NativeArray to a new backend framework.
Parameters
----------
nest
Input nest with the leaves to be converted to a new backend.
native
Whether to return the new array as a ivy.NativeArray or an ivy.Array.
Default is ``True``.
shallow
Whether to inplace update the input nest or not
Only works if nest is a mutable type. Default is ``True``.
Returns
-------
ret
The input nest with leaves converted to the new backend framework.
"""
return ivy.nested_map(
lambda x: array_to_new_backend(x, native=native),
nest,
include_derived=True,
shallow=shallow,
)
1 change: 1 addition & 0 deletions ivy/functional/frontends/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def promote_types_of_torch_inputs(
return x1, x2


from . import utils
from . import nn
from .nn.functional import softmax, relu, lstm
from . import tensor
Expand Down
4 changes: 4 additions & 0 deletions ivy/functional/frontends/torch/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from . import functional
from . import modules
from .modules import *
from . import parameter
from .parameter import Parameter
2 changes: 2 additions & 0 deletions ivy/functional/frontends/torch/nn/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import module
from .module import Module
Loading

0 comments on commit 6daeb0e

Please sign in to comment.