From 8997af3e455b29344b4cb6017950c5b176240b5d Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Thu, 14 Sep 2023 13:59:00 +0100 Subject: [PATCH] feat: add support for complex dtype to relu6 (#23599) --- .../array/experimental/activations.py | 13 +++++-- .../container/experimental/activations.py | 10 ++++++ .../backends/jax/experimental/activations.py | 4 ++- .../numpy/experimental/activations.py | 4 ++- .../paddle/experimental/activations.py | 7 +++- .../tensorflow/experimental/activations.py | 3 +- .../torch/experimental/activations.py | 4 ++- .../jax/nn/non_linear_activations.py | 2 +- ivy/functional/frontends/tensorflow/nn.py | 1 + .../ivy/experimental/activations.py | 35 ++++++++++++++++++- ivy/stateful/activations.py | 15 ++++++-- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 6 +++- 13 files changed, 91 insertions(+), 15 deletions(-) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index f27ef89516bbb..e1814c0bafc2d 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -123,7 +123,13 @@ def prelu( """ return ivy.prelu(self._data, slope, out=out) - def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def relu6( + self, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -131,6 +137,9 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: ---------- self input array + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -156,7 +165,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: >>> print(y) ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) """ - return ivy.relu6(self._data, out=out) + return ivy.relu6(self._data, complex_mode=complex_mode, out=out) def logsigmoid( self: ivy.Array, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index 082fb5e062b40..09fe34c656347 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -329,6 +329,7 @@ def static_relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -351,6 +352,9 @@ def static_relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. 
@@ -379,6 +383,7 @@ def static_relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -390,6 +395,7 @@ def relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -412,6 +418,9 @@ def relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -439,6 +448,7 @@ def relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index ccb239dcbc430..8c0799d49aeaf 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -23,7 +23,9 @@ def logit( return jnp.log(x / (1 - x)) -def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def relu6( + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None +) -> JaxArray: relu6_func = jax.nn.relu6 # sets gradient at 0 and 6 to 0 instead of 0.5 diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index 707ecd8a4f4df..cbd29ab5035e5 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -44,7 +44,9 @@ def thresholded_relu( @_scalar_output_to_0d_array -def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def relu6( + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None +) -> np.ndarray: return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 33f1d80bbf8ef..46e095a18bb02 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -54,7 +54,12 @@ def thresholded_relu( ) -def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version +) +def relu6( + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.relu6(x) if paddle.is_complex(x): diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index 709ae09314cd1..a2fb8d19361cd 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -38,8 +38,7 @@ def thresholded_relu( return tf.cast(tf.where(x > threshold, x, 0), x.dtype) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) -def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def relu6(x: Tensor, /, 
*, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor: return tf.nn.relu6(x) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 5969804298c59..eca54bb3612c6 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -34,7 +34,9 @@ def thresholded_relu( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def relu6( + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None +) -> torch.Tensor: return torch.nn.functional.relu6(x) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index a040c1cd5c83e..9bb8090a991ac 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -273,7 +273,7 @@ def relu(x): @to_ivy_arrays_and_back def relu6(x): - res = ivy.relu6(x) + res = ivy.relu6(x, complex_mode="jax") return _type_conversion_64(res) diff --git a/ivy/functional/frontends/tensorflow/nn.py b/ivy/functional/frontends/tensorflow/nn.py index 64ef3aab71f0f..60ba268da02d6 100644 --- a/ivy/functional/frontends/tensorflow/nn.py +++ b/ivy/functional/frontends/tensorflow/nn.py @@ -440,6 +440,7 @@ def relu(features, name=None): return ivy.relu(features) +@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, "tensorflow") @to_ivy_arrays_and_back def relu6(features, name=None): return ivy.relu6(features) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index d4f701cf45b0d..7886c353de159 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -214,6 +214,28 @@ def thresholded_relu( return current_backend(x).thresholded_relu(x, threshold=threshold, out=out) +def _relu6_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + return ivy.where( + ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) + ), + ivy.array(0, dtype=x.dtype), + ivy.where( + ivy.logical_or( + ivy.real(x) > 6, ivy.logical_and(ivy.real(x) == 6, ivy.imag(x) > 0) + ), + ivy.array(6, dtype=x.dtype), + x, + ), + ) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -222,8 +244,13 @@ def thresholded_relu( @to_native_arrays_and_back @handle_array_function @handle_device_shifting +@handle_complex_input def relu6( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -232,6 +259,9 @@ def relu6( ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. 
@@ -260,6 +290,9 @@ def relu6(
     return current_backend(x).relu6(x, out=out)
 
 
+relu6.jax_like = _relu6_jax_like
+
+
 @handle_exceptions
 @handle_backend_invalid
 @handle_nestable
diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py
index 55074d5b3ebde..218941bfb37fd 100644
--- a/ivy/stateful/activations.py
+++ b/ivy/stateful/activations.py
@@ -355,8 +355,17 @@ def _forward(self, x):
 
 
 class ReLU6(Module):
-    def __init__(self):
-        """Apply the RELU6 activation function."""
+    def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
+        """
+        Apply the RELU6 activation function.
+
+        Parameters
+        ----------
+        complex_mode
+            Specifies how to handle complex input. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        """
+        self._complex_mode = complex_mode
         Module.__init__(self)
 
     def _forward(self, x):
@@ -372,7 +381,7 @@ def _forward(self, x):
         ret
             The outputs following the RELU6 activation *[batch_shape, d]*
         """
-        return ivy.relu6(x)
+        return ivy.relu6(x, complex_mode=self._complex_mode)
 
 
 class Hardswish(Module):
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
index 7f311c4eb7c37..5c2d7545cfc16 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
@@ -579,7 +579,7 @@ def test_jax_relu(
 @handle_frontend_test(
     fn_tree="jax.nn.relu6",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_integer"),
+        available_dtypes=helpers.get_dtypes("numeric"),
         large_abs_safety_factor=2,
         small_abs_safety_factor=2,
         safety_factor_scale="linear",
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
index 7c213994ed770..137deab287f8e 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
@@ -122,8 +122,11 @@ def test_prelu(*, dtype_and_x, slope, test_flags, backend_fw, fn_name, on_device
         small_abs_safety_factor=2,
         safety_factor_scale="log",
     ),
+    complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
 )
-def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
+def test_relu6(
+    *, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device
+):
     dtype, x = dtype_and_x
     helpers.test_function(
         input_dtypes=dtype,
@@ -132,6 +135,7 @@ def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
         fn_name=fn_name,
         on_device=on_device,
         x=x[0],
+        complex_mode=complex_mode,
     )
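
A minimal usage sketch (not part of the patch itself) showing how the new ``complex_mode`` argument is expected to behave on complex input. The sample values and the printed result are assumptions derived from the ``_relu6_jax_like`` branching added in ``ivy/functional/ivy/experimental/activations.py``, and assume a backend with complex dtype support:

import ivy

# A complex input with one entry below the lower cut-off, one inside [0, 6],
# and one above the upper cut-off (judged by the real part).
x = ivy.array([-1.0 + 2.0j, 3.0 + 1.0j, 7.0 - 0.5j])

# "jax" (the default) maps entries whose real part is negative to 0 and
# entries whose real part exceeds 6 to 6; everything else passes through.
y = ivy.relu6(x, complex_mode="jax")
print(y)  # expected: ivy.array([0.+0.j, 3.+1.j, 6.+0.j])

# "split" and "magnitude" are the other strategies accepted by
# ivy.func_wrapper.handle_complex_input.
z = ivy.relu6(x, complex_mode="split")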