Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added complex dtype support for silu #23607

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions ivy/data_classes/array/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,13 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
"""
return ivy.selu(self._data, out=out)

def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def silu(
self: ivy.Array,
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.silu. This method simply wraps the
function, and so the docstring for ivy.silu also applies to this method with
Expand All @@ -248,7 +254,10 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
Parameters
----------
self
input array.
input array
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand All @@ -260,7 +269,7 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([-0.26894143, 0. , 0.73105854])
"""
return ivy.silu(self._data, out=out)
return ivy.silu(self._data, complex_mode=complex_mode, out=out)

def elu(
self,
Expand Down
10 changes: 10 additions & 0 deletions ivy/data_classes/container/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,7 @@ def _static_silu(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -719,6 +720,9 @@ def _static_silu(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -746,6 +750,7 @@ def _static_silu(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

Expand All @@ -757,6 +762,7 @@ def silu(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -779,6 +785,9 @@ def silu(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand All @@ -805,6 +814,7 @@ def silu(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/jax/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def selu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return ret


def silu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def silu(
x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None
) -> JaxArray:
ret = jax.nn.silu(x)
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def selu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:


@_scalar_output_to_0d_array
def silu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def silu(
x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None
) -> np.ndarray:
ret = np.asarray(x * (1 / (1 + np.exp(-x))))
if ivy.exists(out):
return ivy.inplace_update(out, ret).astype(x.dtype)
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/paddle/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.
return F.selu(x.cast("float32")).cast(x.dtype)


def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
def silu(
x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.silu(x)
if paddle.is_complex(x):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def selu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
return ivy.astype(ret, x.dtype)


@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
def silu(
x: Tensor,
/,
*,
complex_mode="jax",
out: Optional[Tensor] = None,
) -> Tensor:
ret = tf.nn.silu(x)
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/torch/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
def silu(
    x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Apply the SiLU (swish) activation, ``x * sigmoid(x)``, element-wise.

    ``complex_mode`` and ``out`` are accepted for signature parity with the
    other ivy backends but are not used here; the computation is delegated
    directly to PyTorch.
    """
    activated = torch.nn.functional.silu(x)
    return activated


Expand Down
5 changes: 3 additions & 2 deletions ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,14 @@ def sigmoid(x):


@with_supported_dtypes(
{"0.4.14 and below": ("complex", "float")},
{"0.4.14 and below": ("float",)},
"jax",
)
@to_ivy_arrays_and_back
def silu(x):
    """
    Jax frontend of silu: convert the input via ``_type_conversion`` and
    delegate to ``ivy.silu`` with jax-style complex handling.
    """
    x = _type_conversion(x)
    return ivy.silu(x, complex_mode="jax")


@to_ivy_arrays_and_back
Expand Down
9 changes: 8 additions & 1 deletion ivy/functional/ivy/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,11 @@ def selu(
@handle_array_function
@handle_device_shifting
def silu(
Comment on lines 420 to 422
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you missed the @handle_complex_input decorator

x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the silu function element-wise.
Expand All @@ -429,6 +433,9 @@ def silu(
----------
x
input array.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Expand Down
17 changes: 13 additions & 4 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,17 @@ def _forward(self, x):


class SiLU(Module):
def __init__(self):
"""Apply the SiLU activation function."""
def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
    """
    Apply the SiLU activation function.

    Parameters
    ----------
    complex_mode
        Specifies how to handle complex input. See
        ``ivy.func_wrapper.handle_complex_input`` for more detail.
    """
    # Stored for use by _forward, which passes it through to ivy.silu.
    self._complex_mode = complex_mode
    Module.__init__(self)

def _forward(self, x):
Expand All @@ -291,7 +300,7 @@ def _forward(self, x):
ret
The outputs following the SiLU activation *[batch_shape, d]*
"""
return ivy.silu(x)
return ivy.silu(x, complex_mode=self._complex_mode)


class Sigmoid(Module):
Expand Down Expand Up @@ -357,7 +366,7 @@ def _forward(self, x):
class ReLU6(Module):
def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
"""
Apply the TANH activation function.
Apply the RELU6 activation function.

Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def test_jax_sigmoid(
@handle_frontend_test(
fn_tree="jax.nn.silu",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float_and_integer"),
available_dtypes=helpers.get_dtypes("numeric"),
large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,14 @@ def test_selu(*, dtype_and_input, test_flags, backend_fw, fn_name, on_device):
@handle_test(
fn_tree="functional.ivy.experimental.silu",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
available_dtypes=helpers.get_dtypes("float_and_complex"),
large_abs_safety_factor=8,
small_abs_safety_factor=8,
safety_factor_scale="log",
),
complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
)
def test_silu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
def test_silu(*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device):
dtype, x = dtype_and_x
helpers.test_function(
input_dtypes=dtype,
Expand All @@ -184,6 +185,7 @@ def test_silu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
rtol_=1e-02,
atol_=1e-02,
x=x[0],
complex_mode=complex_mode,
)


Expand Down
Loading