diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index e1814c0bafc2d..c7f631980841e 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -239,7 +239,13 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: """ return ivy.selu(self._data, out=out) - def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def silu( + self: ivy.Array, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ ivy.Array instance method variant of ivy.silu. This method simply wraps the function, and so the docstring for ivy.silu also applies to this method with @@ -248,7 +254,10 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: Parameters ---------- self - input array. + input array + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -260,7 +269,7 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: >>> print(y) ivy.array([-0.26894143, 0. 
, 0.73105854]) """ - return ivy.silu(self._data, out=out) + return ivy.silu(self._data, complex_mode=complex_mode, out=out) def elu( self, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index 09fe34c656347..94b85c48182e9 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -697,6 +697,7 @@ def _static_silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -719,6 +720,9 @@ def _static_silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -746,6 +750,7 @@ def _static_silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -757,6 +762,7 @@ def silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -779,6 +785,9 @@ def silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. 
@@ -805,6 +814,7 @@ def silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index 8c0799d49aeaf..629fbab5a86e4 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -64,7 +64,9 @@ def selu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return ret -def silu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def silu( + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None +) -> JaxArray: ret = jax.nn.silu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index cbd29ab5035e5..8834caecf811a 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -75,7 +75,9 @@ def selu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @_scalar_output_to_0d_array -def silu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def silu( + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None +) -> np.ndarray: ret = np.asarray(x * (1 / (1 + np.exp(-x)))) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 46e095a18bb02..d8f18ddca3116 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -102,7 +102,9 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
return F.selu(x.cast("float32")).cast(x.dtype) -def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def silu( + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.silu(x) if paddle.is_complex(x): diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index a2fb8d19361cd..dca469c540763 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -59,11 +59,11 @@ def selu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: return ivy.astype(ret, x.dtype) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def silu( x: Tensor, /, *, + complex_mode="jax", out: Optional[Tensor] = None, ) -> Tensor: ret = tf.nn.silu(x) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index eca54bb3612c6..145e38ec3b8bc 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -58,7 +58,9 @@ def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def silu( + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None +) -> torch.Tensor: return torch.nn.functional.silu(x) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 9bb8090a991ac..3f719e6fddd68 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -291,13 +291,14 
@@ def sigmoid(x): @with_supported_dtypes( - {"0.4.14 and below": ("complex", "float")}, + {"0.4.14 and below": ("float",)}, "jax", ) @to_ivy_arrays_and_back def silu(x): x = _type_conversion(x) - return ivy.multiply(x, ivy.sigmoid(x)) + # return ivy.multiply(x, ivy.sigmoid(x)) + return ivy.silu(x, complex_mode="jax") @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 7886c353de159..103462704220b 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -420,7 +420,11 @@ def selu( @handle_array_function @handle_device_shifting def silu( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the silu function element-wise. @@ -429,6 +433,9 @@ def silu( ---------- x input array. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 218941bfb37fd..ba7a01fbf470c 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -274,8 +274,17 @@ def _forward(self, x): class SiLU(Module): - def __init__(self): - """Apply the SiLU activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the SiLU activation function. + + Parameters + ---------- + complex_mode + Specifies how to handle complex input. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
+ """ + self._complex_mode = complex_mode Module.__init__(self) def _forward(self, x): @@ -291,7 +300,7 @@ def _forward(self, x): ret The outputs following the SiLU activation *[batch_shape, d]* """ - return ivy.silu(x) + return ivy.silu(x, complex_mode=self._complex_mode) class Sigmoid(Module): @@ -357,7 +366,7 @@ def _forward(self, x): class ReLU6(Module): def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): """ - Apply the TANH activation function. + Apply the RELU6 activation function. Parameters ---------- diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 5c2d7545cfc16..5f106cd193862 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -674,7 +674,7 @@ def test_jax_sigmoid( @handle_frontend_test( fn_tree="jax.nn.silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 137deab287f8e..1619f015579d8 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -167,13 +167,14 @@ def test_selu(*, dtype_and_input, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, 
small_abs_safety_factor=8, safety_factor_scale="log", ), + complex_mode=st.sampled_from(["jax", "split", "magnitude"]), ) -def test_silu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_silu(*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -184,6 +185,7 @@ def test_silu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): rtol_=1e-02, atol_=1e-02, x=x[0], + complex_mode=complex_mode, )