Activation complex dtype relu6 #23599

Merged
Changes from 3 commits
13 changes: 11 additions & 2 deletions ivy/data_classes/array/experimental/activations.py
@@ -123,14 +123,23 @@ def prelu(
"""
return ivy.prelu(self._data, slope, out=out)

-    def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
+    def relu6(
+        self,
+        /,
+        *,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+        out: Optional[ivy.Array] = None,
+    ) -> ivy.Array:
"""
Apply the rectified linear unit 6 function element-wise.

Parameters
----------
self
input array
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
@@ -156,7 +165,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
"""
-        return ivy.relu6(self._data, out=out)
+        return ivy.relu6(self._data, complex_mode=complex_mode, out=out)

def logsigmoid(
self: ivy.Array,
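To illustrate the new keyword, here is a minimal usage sketch for the array method; the complex input values and the expected output are illustrative, derived from the jax-like rule added later in this diff, and are not part of the PR:

import ivy

x = ivy.array([-1.0 + 2.0j, 3.0 + 4.0j, 7.0 + 1.0j])
y = x.relu6(complex_mode="jax")
# Values whose real part is below 0 clamp to 0, values whose real part
# is above 6 clamp to 6, and everything in between passes through:
# ivy.array([0.+0.j, 3.+4.j, 6.+0.j])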
10 changes: 10 additions & 0 deletions ivy/data_classes/container/experimental/activations.py
@@ -329,6 +329,7 @@ def static_relu6(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
@@ -351,6 +352,9 @@ def static_relu6(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
@@ -379,6 +383,7 @@ def static_relu6(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
+            complex_mode=complex_mode,
out=out,
)

@@ -390,6 +395,7 @@ def relu6(
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
@@ -412,6 +418,9 @@ def relu6(
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
@@ -439,6 +448,7 @@ def relu6(
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
+            complex_mode=complex_mode,
out=out,
)

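A similar hedged sketch for the container method, assuming a container with one real and one complex leaf (values illustrative, not from the PR):

import ivy

c = ivy.Container(a=ivy.array([-1.0, 3.0, 8.0]), b=ivy.array([2.0 + 1.0j]))
out = ivy.Container.static_relu6(c, complex_mode="jax")
# a: real values clamped to [0, 6] -> ivy.array([0., 3., 6.])
# b: real part lies in (0, 6), so the value passes through -> ivy.array([2.+1.j])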
4 changes: 3 additions & 1 deletion ivy/functional/backends/jax/experimental/activations.py
@@ -23,7 +23,9 @@ def logit(
return jnp.log(x / (1 - x))


-def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
+def relu6(
+    x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None
+) -> JaxArray:
relu6_func = jax.nn.relu6

# sets gradient at 0 and 6 to 0 instead of 0.5
4 changes: 3 additions & 1 deletion ivy/functional/backends/numpy/experimental/activations.py
@@ -44,7 +44,9 @@ def thresholded_relu(


@_scalar_output_to_0d_array
-def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
+def relu6(
+    x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None
+) -> np.ndarray:
return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype)


7 changes: 6 additions & 1 deletion ivy/functional/backends/paddle/experimental/activations.py
@@ -54,7 +54,12 @@ def thresholded_relu(
)


-def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
+)
+def relu6(
+    x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None
+) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.relu6(x)
if paddle.is_complex(x):
ivy/functional/backends/tensorflow/experimental/activations.py
@@ -38,8 +38,8 @@ def thresholded_relu(
return tf.cast(tf.where(x > threshold, x, 0), x.dtype)


-@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
-def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
+# @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
Contributor: This line should probably be removed completely rather than just commented out

Contributor (author): Done

+def relu6(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor:
return tf.nn.relu6(x)


4 changes: 3 additions & 1 deletion ivy/functional/backends/torch/experimental/activations.py
@@ -34,7 +34,9 @@ def thresholded_relu(


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version)
-def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
+def relu6(
+    x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None
+) -> torch.Tensor:
return torch.nn.functional.relu6(x)


2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
@@ -273,7 +273,7 @@ def relu(x):

@to_ivy_arrays_and_back
def relu6(x):
-    res = ivy.relu6(x)
+    res = ivy.relu6(x, complex_mode="jax")
return _type_conversion_64(res)


35 changes: 34 additions & 1 deletion ivy/functional/ivy/experimental/activations.py
@@ -214,6 +214,28 @@ def thresholded_relu(
return current_backend(x).thresholded_relu(x, threshold=threshold, out=out)


+def _relu6_jax_like(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    fn_original=None,
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    return ivy.where(
+        ivy.logical_or(
+            ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0)
+        ),
+        ivy.array(0, dtype=x.dtype),
+        ivy.where(
+            ivy.logical_or(
+                ivy.real(x) > 6, ivy.logical_and(ivy.real(x) == 6, ivy.imag(x) > 0)
+            ),
+            ivy.array(6, dtype=x.dtype),
+            x,
+        ),
+    )


@handle_exceptions
@handle_backend_invalid
@handle_nestable
@@ -222,8 +244,13 @@ def thresholded_relu(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
+@handle_complex_input
def relu6(
-    x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Apply the rectified linear unit 6 function element-wise.
@@ -232,6 +259,9 @@ def relu6(
----------
x
input array
+    complex_mode
+        optional specifier for how to handle complex data types. See
+        ``ivy.func_wrapper.handle_complex_input`` for more detail.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
@@ -260,6 +290,9 @@ def relu6(
return current_backend(x).relu6(x, out=out)


+relu6.jax_like = _relu6_jax_like


@handle_exceptions
@handle_backend_invalid
@handle_nestable
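The _relu6_jax_like helper is what @handle_complex_input dispatches to when a complex-typed input arrives with complex_mode="jax"; the relu6.jax_like = _relu6_jax_like assignment registers it. It effectively orders complex numbers lexicographically by (real, imag): a value clamps to 0 when real(x) < 0, or when real(x) == 0 and imag(x) < 0, and clamps to 6 when real(x) > 6, or when real(x) == 6 and imag(x) > 0. A pure-Python scalar sketch of the same rule, for illustration only (not part of the PR):

def relu6_jax_like_scalar(z: complex) -> complex:
    # lexicographically below zero -> clamp to 0
    if z.real < 0 or (z.real == 0 and z.imag < 0):
        return 0 + 0j
    # lexicographically above six -> clamp to 6
    if z.real > 6 or (z.real == 6 and z.imag > 0):
        return 6 + 0j
    return z

assert relu6_jax_like_scalar(0 - 1j) == 0        # real == 0, imag < 0: clamped
assert relu6_jax_like_scalar(6 + 1j) == 6        # real == 6, imag > 0: clamped
assert relu6_jax_like_scalar(6 - 1j) == 6 - 1j   # on the boundary but not above: passes through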
15 changes: 12 additions & 3 deletions ivy/stateful/activations.py
@@ -355,8 +355,17 @@ def _forward(self, x):


class ReLU6(Module):
-    def __init__(self):
-        """Apply the RELU6 activation function."""
+    def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
+        """
+        Apply the RELU6 activation function.
+
+        Parameters
+        ----------
+        complex_mode
+            Specifies how to handle complex input. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        """
+        self._complex_mode = complex_mode
Module.__init__(self)

def _forward(self, x):
@@ -372,7 +381,7 @@ def _forward(self, x):
ret
The outputs following the RELU6 activation *[batch_shape, d]*
"""
-        return ivy.relu6(x)
+        return ivy.relu6(x, complex_mode=self._complex_mode)


class Hardswish(Module):
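A short usage sketch for the updated stateful layer (the import path is taken from this diff; values are illustrative):

import ivy
from ivy.stateful.activations import ReLU6

layer = ReLU6(complex_mode="jax")
y = layer(ivy.array([-3.0, 2.0, 9.0]))
# real inputs are clamped to [0, 6] as before:
# ivy.array([0., 2., 6.])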
ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
@@ -579,7 +579,7 @@ def test_jax_relu(
@handle_frontend_test(
fn_tree="jax.nn.relu6",
dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_integer"),
+        available_dtypes=helpers.get_dtypes("numeric"),
large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
@@ -122,8 +122,11 @@ def test_prelu(*, dtype_and_x, slope, test_flags, backend_fw, fn_name, on_device
small_abs_safety_factor=2,
safety_factor_scale="log",
),
+    complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
)
-def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
+def test_relu6(
+    *, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device
+):
dtype, x = dtype_and_x
helpers.test_function(
input_dtypes=dtype,
@@ -132,6 +135,7 @@ def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
fn_name=fn_name,
on_device=on_device,
x=x[0],
+        complex_mode=complex_mode,
)

