Skip to content

Commit

Permalink
feat: add support for complex dtype to relu6 (#23599)
Browse files Browse the repository at this point in the history
  • Loading branch information
mosesdaudu001 authored Sep 14, 2023
1 parent 4469bfe commit 8997af3
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 15 deletions.
13 changes: 11 additions & 2 deletions ivy/data_classes/array/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,23 @@ def prelu(
"""
return ivy.prelu(self._data, slope, out=out)

def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def relu6(
self,
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the rectified linear unit 6 function element-wise.
Parameters
----------
self
input array
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
Expand All @@ -156,7 +165,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
"""
return ivy.relu6(self._data, out=out)
return ivy.relu6(self._data, complex_mode=complex_mode, out=out)

def logsigmoid(
self: ivy.Array,
Expand Down
10 changes: 10 additions & 0 deletions ivy/data_classes/container/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def static_relu6(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -351,6 +352,9 @@ def static_relu6(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -379,6 +383,7 @@ def static_relu6(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

Expand All @@ -390,6 +395,7 @@ def relu6(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
Expand All @@ -412,6 +418,9 @@ def relu6(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -439,6 +448,7 @@ def relu6(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
complex_mode=complex_mode,
out=out,
)

Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/jax/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def logit(
return jnp.log(x / (1 - x))


def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def relu6(
x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None
) -> JaxArray:
relu6_func = jax.nn.relu6

# sets gradient at 0 and 6 to 0 instead of 0.5
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def thresholded_relu(


@_scalar_output_to_0d_array
def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def relu6(
x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None
) -> np.ndarray:
return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)


Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/backends/paddle/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def thresholded_relu(
)


def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def relu6(
x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.relu6(x)
if paddle.is_complex(x):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def thresholded_relu(
return tf.cast(tf.where(x > threshold, x, 0), x.dtype)


@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
def relu6(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor:
return tf.nn.relu6(x)


Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/torch/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def thresholded_relu(


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
def relu6(
x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.nn.functional.relu6(x)


Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def relu(x):

@to_ivy_arrays_and_back
def relu6(x):
res = ivy.relu6(x)
res = ivy.relu6(x, complex_mode="jax")
return _type_conversion_64(res)


Expand Down
1 change: 1 addition & 0 deletions ivy/functional/frontends/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def relu(features, name=None):
return ivy.relu(features)


@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, "tensorflow")
@to_ivy_arrays_and_back
def relu6(features, name=None):
return ivy.relu6(features)
Expand Down
35 changes: 34 additions & 1 deletion ivy/functional/ivy/experimental/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,28 @@ def thresholded_relu(
return current_backend(x).thresholded_relu(x, threshold=threshold, out=out)


def _relu6_jax_like(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
fn_original=None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
return ivy.where(
ivy.logical_or(
ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0)
),
ivy.array(0, dtype=x.dtype),
ivy.where(
ivy.logical_or(
ivy.real(x) > 6, ivy.logical_and(ivy.real(x) == 6, ivy.imag(x) > 0)
),
ivy.array(6, dtype=x.dtype),
x,
),
)


@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand All @@ -222,8 +244,13 @@ def thresholded_relu(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def relu6(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the rectified linear unit 6 function element-wise.
Expand All @@ -232,6 +259,9 @@ def relu6(
----------
x
input array
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Expand Down Expand Up @@ -260,6 +290,9 @@ def relu6(
return current_backend(x).relu6(x, out=out)


relu6.jax_like = _relu6_jax_like


@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand Down
15 changes: 12 additions & 3 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,17 @@ def _forward(self, x):


class ReLU6(Module):
def __init__(self):
"""Apply the RELU6 activation function."""
def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
"""
Apply the TANH activation function.
Parameters
----------
complex_mode
Specifies how to handle complex input. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
"""
self._complex_mode = complex_mode
Module.__init__(self)

def _forward(self, x):
Expand All @@ -372,7 +381,7 @@ def _forward(self, x):
ret
The outputs following the RELU6 activation *[batch_shape, d]*
"""
return ivy.relu6(x)
return ivy.relu6(x, complex_mode=self._complex_mode)


class Hardswish(Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def test_jax_relu(
@handle_frontend_test(
fn_tree="jax.nn.relu6",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float_and_integer"),
available_dtypes=helpers.get_dtypes("numeric"),
large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,11 @@ def test_prelu(*, dtype_and_x, slope, test_flags, backend_fw, fn_name, on_device
small_abs_safety_factor=2,
safety_factor_scale="log",
),
complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
)
def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
def test_relu6(
*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device
):
dtype, x = dtype_and_x
helpers.test_function(
input_dtypes=dtype,
Expand All @@ -132,6 +135,7 @@ def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
fn_name=fn_name,
on_device=on_device,
x=x[0],
complex_mode=complex_mode,
)


Expand Down

0 comments on commit 8997af3

Please sign in to comment.